Support qwen2 model, optimize phi3 model, revise model loading strategy #46

Merged · 5 commits · Jul 5, 2024
7 changes: 5 additions & 2 deletions Cargo.lock

Some generated files are not rendered by default.

2 changes: 1 addition & 1 deletion Cargo.toml
@@ -35,7 +35,7 @@ env_logger = "0.10.1"
tracing = "0.1.40"
range-checked = { git = "https://github.com/EricLBuehler/range-checked.git", version = "0.1.0" }
chrono = { version = "0.4.31", features = ["clock"] }
either = "1.9.0"
either = { version = "1.13.0", features = ["serde"] }
dirs = "5.0.1"
kernels = {path = "./kernels", version="0.1.0"}

36 changes: 29 additions & 7 deletions README.md
@@ -19,15 +19,15 @@ Currently, candle-vllm supports chat serving for the following models.

| Model ID | Model Type | Supported | Speed (A100, BF16)
|--|--|--|--|
| #1 | **LLAMA/LLAMA2/LLaMa3** |✅|71 tks/s (7B)|
| #1 | **LLAMA/LLAMA2/LLaMa3** |✅|73 tks/s (7B)|
| #2 | Mistral |TBD|TBD|
| #3 | Phi (v1, v1.5, v2) |TBD|TBD|
| #4 | **Phi-3 (3.8B, 7B)** |✅|99 tks/s (3.8B)|
| #4 | **Phi-3 (3.8B, 7B)** |✅|102 tks/s (3.8B)|
| #5 | Yi |TBD|TBD|
| #6 | StableLM |TBD|TBD|
| #7 | BigCode/StarCode |TBD|TBD|
| #8 | ChatGLM |TBD|TBD|
| #9 | QWen |TBD|TBD|
| #9 | **QWen2 (1.8B, 7B)** |✅|148 tks/s (1.8B)|
| #10 | Google Gemma |TBD|TBD|
| #11 | Blip-large (Multimodal) |TBD|TBD|
| #12 | Moondream-2 (Multimodal LLM) |TBD|TBD|
@@ -47,7 +47,12 @@ sudo apt install libssl-dev
sudo apt install pkg-config
git clone [email protected]:EricLBuehler/candle-vllm.git
cd candle-vllm
cargo run --release -- --port 2000 --weight-path /home/llama2_7b/ llama7b --repeat-last-n 64
cargo run --release -- --port 2000 --weight-path /home/llama2_7b/ llama --repeat-last-n 64
```

You may also run a specific model using its Hugging Face model ID, e.g.,
```
cargo run --release -- --port 2000 --model-id meta-llama/Llama-2-7b-chat-hf llama --repeat-last-n 64
```

### Step 2:
@@ -105,7 +110,7 @@ openai.api_key = "EMPTY"
openai.base_url = "http://localhost:2000/v1/"

completion = openai.chat.completions.create(
model="llama7b",
model="llama",
messages=[
{
"role": "user",
@@ -124,16 +129,33 @@ After the `candle-vllm` service is running, run the Python script and enjoy effi
## Usage Help
For general configuration help, run `cargo run -- --help`.

For model-specific help, run `cargo run -- --port 1234 <MODEL NAME> --help`
For model-specific help, run `cargo run -- --port 2000 <MODEL_TYPE> --help`

For local model weights, run `cargo run --release -- --port 2000 --weight-path /home/llama2_7b/ llama --repeat-last-n 64`, changing the path as needed.

For local model weights, run `cargo run --release -- --port 2000 --weight-path /home/llama2_7b/ llama7b --repeat-last-n 64`, change the path when needed.
`MODEL_TYPE` = ["llama", "phi3", "qwen2"]

`WEIGHT_FILE_PATH` = Corresponding weight path for the given model type

```
cargo run --release --features gcu -- --port 2000 --weight-path <WEIGHT_FILE_PATH> <MODEL_TYPE> --repeat-last-n 64
```

or

`MODEL_ID` = Huggingface model id

```
cargo run --release --features gcu -- --port 2000 --model-id <MODEL_ID> <MODEL_TYPE> --repeat-last-n 64
```

For kvcache configuration, set `kvcache_mem_cpu` and `kvcache_mem_gpu`; the default is 4GB of CPU memory and 4GB of GPU memory for the kvcache.

For chat history settings, set `record_conversation` to `true` to let candle-vllm remember the chat history. By default, candle-vllm does not record chat history; instead, the client sends both the new messages and the contextual history to candle-vllm. If `record_conversation` is set to `true`, the client sends only new chat messages, and candle-vllm is responsible for recording the previous messages. However, this approach requires per-session chat recording, which is not yet implemented, so the default `record_conversation=false` is recommended.
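
With the default `record_conversation=false`, the client owns the conversation and resends the accumulated history on every request. A minimal sketch using the same `openai` client as the example above (the model name `llama` and the endpoint mirror the earlier snippet; the prompts are placeholders):

```python
import openai

openai.api_key = "EMPTY"
openai.base_url = "http://localhost:2000/v1/"

# The client keeps the conversation history and resends it each turn,
# since candle-vllm does not record it by default.
messages = []

def chat(user_text: str) -> str:
    messages.append({"role": "user", "content": user_text})
    completion = openai.chat.completions.create(
        model="llama",  # must match the MODEL_TYPE passed on the command line
        messages=messages,
    )
    reply = completion.choices[0].message.content
    # Append the assistant reply so the next request carries the full context.
    messages.append({"role": "assistant", "content": reply})
    return reply

print(chat("What is candle-vllm?"))
print(chat("Summarize your previous answer in one sentence."))
```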

For chat streaming, the `stream` flag in the chat request needs to be set to `True`.

You may adjust the `repetition_penalty` and `temperature` fields in the chat request (HTTP POST).
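
A minimal streaming sketch with the same client, combining `stream=True`, `temperature`, and `repetition_penalty`. Since `repetition_penalty` is not part of the standard OpenAI schema, it is forwarded via `extra_body` here; the exact field name the server expects is an assumption based on the text above:

```python
import openai

openai.api_key = "EMPTY"
openai.base_url = "http://localhost:2000/v1/"

# Stream tokens as they are generated; sampling parameters ride along
# in the same chat request.
stream = openai.chat.completions.create(
    model="llama",
    messages=[{"role": "user", "content": "Explain KV caching in two sentences."}],
    stream=True,
    temperature=0.7,
    # Non-standard field, passed through as an extra JSON body field.
    extra_body={"repetition_penalty": 1.1},
)

for chunk in stream:
    if chunk.choices and chunk.choices[0].delta.content:
        print(chunk.choices[0].delta.content, end="", flush=True)
print()
```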

## Report issue
Installing `candle-vllm` is as simple as the following steps. If you have any problems, please create an
29 changes: 2 additions & 27 deletions src/backend/cache.rs
@@ -1,15 +1,14 @@
use std::{collections::HashMap, iter::zip, ptr::NonNull};

use crate::{
backend::{dispatch_get_cuda_pointer, get_or_load_func, Conjoined},
backend::{get_or_load_func, Conjoined},
openai::responses::APIError,
try_api,
};
use candle_core::cuda_backend::CudaError;
use candle_core::cuda_backend::CudaStorageSlice;
use candle_core::{
cuda_backend::cudarc::driver::{CudaSlice, DevicePtr, LaunchAsync, LaunchConfig},
DType, Device, IndexOp, Storage, Tensor,
Device, IndexOp, Storage, Tensor,
};

use super::COPY_BLOCKS_KERNEL_NAME;
@@ -231,30 +230,6 @@ pub fn swap_blocks(
try_api!(dst_dev.htod_sync_copy_into(&src_slice[src_offset..src_offset+block_size_in_bytes], &mut dst_slice));
}
}
(Device::Cuda(src_dev), Device::Cpu) => {
// Pending on huggingface/candle#1467
todo!();
/*let (src_storage, src_layout) = src.storage_and_layout();
let (dst_storage, dst_layout) = dst.storage_mut_and_layout();
assert!(matches!(&*src_storage, Storage::Cuda(_)));
assert!(matches!(&*dst_storage, Storage::Cpu(_)));
let Storage::Cuda(src_storage) = &*src_storage else { unreachable!() };
let Storage::Cpu(dst_storage) = &*dst_storage else { unreachable!() };
let src_ptr = src_storage.as_cuda_slice::<u8>().map_err(APIError::from)?.device_ptr() + TryInto::<u64>::try_into(src_layout.start_offset()).unwrap();
let dst_slice: &[u8] = try_api!(dst_storage.as_slice());
let ptr = dst_slice.as_ptr() as *mut u8;
// Safety:
let dst_slice = unsafe { slice::from_raw_parts_mut(ptr, dst_slice.len()) };

for (src_block_number, dst_block_number) in block_mapping {
let src_offset: u64 = (src_block_number * block_size_in_bytes).try_into().unwrap();
let dst_offset: u64 = (dst_block_number * block_size_in_bytes).try_into().unwrap();
// u8s because we copy by bytes
let src_slice: CudaSlice<u8> = unsafe { src_dev.upgrade_device_ptr(src_ptr+src_offset, block_size_in_bytes) };

try_api!(src_dev.dtoh_sync_copy_into(&src_slice, dst_slice));
}*/
}
(src, dst) => {
return Err(APIError::new(format!("Tensors must be on either the GPU or CPU to swap,, got {src:?} (src) and {dst:?} (dst).")))
}
28 changes: 2 additions & 26 deletions src/backend/mod.rs
@@ -2,7 +2,6 @@ mod cache;
mod paged_attention;

const COPY_BLOCKS_KERNEL_NAME: &str = "copy_blocks_kernel";
const RESHAPE_AND_CACHE_KERNEL_NAME: &str = "reshape_and_cache_kernel";

pub fn get_or_load_func(
ptx_file: &'static str,
@@ -58,34 +57,11 @@ unsafe impl<'a, T, R> DeviceRepr for Conjoined<'a, T, R> {
}
}

fn dispatch_get_cuda_pointer(tensor: Tensor) -> u64 {
match tensor.dtype() {
DType::BF16 => get_cuda_pointer::<bf16>(tensor),
DType::F16 => get_cuda_pointer::<f16>(tensor),
DType::U8 => get_cuda_pointer::<u8>(tensor),
DType::U32 => get_cuda_pointer::<u32>(tensor),
DType::I64 => get_cuda_pointer::<i64>(tensor),
DType::F32 => get_cuda_pointer::<f32>(tensor),
DType::F64 => get_cuda_pointer::<f64>(tensor),
}
}

fn get_cuda_pointer<T: CudaDType>(tensor: Tensor) -> u64 {
match &*tensor.storage_and_layout().0 {
Storage::Cuda(cuda_storage) => *cuda_storage.as_cuda_slice::<T>().unwrap().device_ptr(),
other => panic!("Unsupported storage `{:?}`", other),
}
}

pub use cache::*;
use candle_core::{
cuda_backend::{
cudarc::driver::{CudaFunction, DevicePtr, DeviceRepr},
CudaDType,
},
CudaDevice, DType, Storage, Tensor,
cuda_backend::cudarc::driver::{CudaFunction, DeviceRepr},
CudaDevice, DType,
};
use half::{bf16, f16};
pub use paged_attention::*;
pub use std::ops::Deref;
use std::{
1 change: 0 additions & 1 deletion src/backend/paged_attention.rs
@@ -1,5 +1,4 @@
// use candle_core::{cuda_backend::cudarc::driver::CudaFunction, DType, Tensor};
use crate::openai::responses::APIError;
use candle::backend::BackendStorage;
use candle::cuda_backend::cudarc::driver::DevicePtr;
use candle::cuda_backend::WrapErr;
66 changes: 33 additions & 33 deletions src/lib.rs
@@ -9,29 +9,22 @@ use openai::pipelines::{

#[derive(Debug, Subcommand)]
pub enum ModelSelected {
/// Select the llama7b model.
Llama7b {
/// Select the llama model (default llama2-7b).
Llama {
/// Control the application of repeat penalty for the last n tokens
#[arg(long)]
repeat_last_n: usize,
},

/// Select the llama13b model.
Llama13b {
/// Control the application of repeat penalty for the last n tokens
#[arg(long)]
repeat_last_n: usize,
},

/// Select the llama70b model.
Llama70b {
/// Select the phi3 model (default 3.8b).
Phi3 {
/// Control the application of repeat penalty for the last n tokens
#[arg(long)]
repeat_last_n: usize,
},

/// Select the phi3 3.8b model.
Phi3 {
/// Select the qwen model (default 1.8b).
Qwen2 {
/// Control the application of repeat penalty for the last n tokens
#[arg(long)]
repeat_last_n: usize,
@@ -41,43 +34,50 @@
impl ToString for ModelSelected {
fn to_string(&self) -> String {
match self {
ModelSelected::Llama7b { repeat_last_n: _ } => "llama7b".to_string(),
ModelSelected::Llama13b { repeat_last_n: _ } => "llama13b".to_string(),
ModelSelected::Llama70b { repeat_last_n: _ } => "llama70b".to_string(),
ModelSelected::Llama { repeat_last_n: _ } => "llama".to_string(),
ModelSelected::Phi3 { repeat_last_n: _ } => "phi3".to_string(),
ModelSelected::Qwen2 { repeat_last_n: _ } => "qwen2".to_string(),
}
}
}

pub fn get_model_loader<'a>(selected_model: ModelSelected) -> (Box<dyn ModelLoader<'a>>, String) {
pub fn get_model_loader<'a>(
selected_model: ModelSelected,
model_id: Option<String>,
) -> (Box<dyn ModelLoader<'a>>, String) {
match selected_model {
ModelSelected::Llama7b { repeat_last_n } => (
ModelSelected::Llama { repeat_last_n } => (
Box::new(DefaultLoader::new(
SpecificConfig::new(repeat_last_n),
"llama7b".to_string(),
"llama".to_string(),
)),
"meta-llama/Llama-2-7b-chat-hf".to_string(),
if model_id.is_some() {
model_id.unwrap()
} else {
"meta-llama/Llama-2-7b-chat-hf".to_string()
},
),
ModelSelected::Llama13b { repeat_last_n } => (
Box::new(DefaultLoader::new(
SpecificConfig::new(repeat_last_n),
"llama13b".to_string(),
)),
"meta-llama/Llama-2-13b-chat-hf".to_string(),
),
ModelSelected::Llama70b { repeat_last_n } => (
ModelSelected::Phi3 { repeat_last_n } => (
Box::new(DefaultLoader::new(
SpecificConfig::new(repeat_last_n),
"llama70b".to_string(),
"phi3".to_string(),
)),
"meta-llama/Llama-2-70b-chat-hf".to_string(),
if model_id.is_some() {
model_id.unwrap()
} else {
"microsoft/Phi-3-mini-4k-instruct".to_string()
},
),
ModelSelected::Phi3 { repeat_last_n } => (
ModelSelected::Qwen2 { repeat_last_n } => (
Box::new(DefaultLoader::new(
SpecificConfig::new(repeat_last_n),
"phi3".to_string(),
"qwen2".to_string(),
)),
"microsoft/Phi-3-mini-4k-instruct".to_string(),
if model_id.is_some() {
model_id.unwrap()
} else {
"Qwen/Qwen1.5-1.8B-Chat".to_string()
},
),
}
}
21 changes: 18 additions & 3 deletions src/main.rs
@@ -15,6 +15,7 @@ use clap::Parser;
use futures::lock::Mutex;
use std::sync::Arc;
const SIZE_IN_MB: usize = 1024 * 1024;
use std::path::Path;

#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
@@ -47,8 +48,12 @@ struct Args {
#[arg(long, default_value_t = 32)]
block_size: usize,

/// if weight_path is passed, it will ignore the model_id
#[arg(long)]
model_id: Option<String>,

/// The folder name that contains safetensor weights and json files
/// (same structure as huggingface online)
/// (same structure as huggingface online), path must include last "/"
#[arg(long)]
weight_path: Option<String>,

@@ -74,13 +79,23 @@
#[actix_web::main]
async fn main() -> Result<(), APIError> {
let args = Args::parse();
let (loader, model_id) = get_model_loader(args.command);
let (loader, model_id) = get_model_loader(args.command, args.model_id.clone());
if args.model_id.is_none() {
println!("No model id specified, using the default model or specified in the weight_path!");
}

let paths = match &args.weight_path {
Some(path) => Box::new(DefaultModelPaths {
tokenizer_filename: (path.to_owned() + "tokenizer.json").into(),
config_filename: (path.to_owned() + "config.json").into(),
filenames: hub_load_local_safetensors(path, "model.safetensors.index.json").unwrap(),
filenames: if Path::new(&(path.to_owned() + "model.safetensors.index.json")).exists() {
hub_load_local_safetensors(path, "model.safetensors.index.json").unwrap()
} else {
//a single weight file case
let mut safetensors_files = Vec::<std::path::PathBuf>::new();
safetensors_files.insert(0, (path.to_owned() + "model.safetensors").into());
safetensors_files
},
}),
_ => loader.download_model(model_id, None, args.hf_token, args.hf_token_path)?,
};