From dcd09cbf8b3d7dfe72820dab1b7f615f6bd40afd Mon Sep 17 00:00:00 2001 From: Guoqing Bao Date: Thu, 9 Jan 2025 03:43:13 +0000 Subject: [PATCH 1/6] Fix kernel rebuild bugs --- .gitignore | 4 +++- kernels/build.rs | 5 +++-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/.gitignore b/.gitignore index 8856ec9..21c3d8b 100644 --- a/.gitignore +++ b/.gitignore @@ -3,4 +3,6 @@ __pycache__ *.ptx launch.json -libpagedattention.a \ No newline at end of file +libpagedattention.a +*.gz +kernels/src/lib.rs \ No newline at end of file diff --git a/kernels/build.rs b/kernels/build.rs index 2f37096..193cef6 100644 --- a/kernels/build.rs +++ b/kernels/build.rs @@ -21,10 +21,10 @@ fn main() -> Result<()> { println!("cargo:rerun-if-changed=src/reshape_and_cache_kernel.cu"); println!("cargo:rerun-if-changed=src/marlin_cuda_kernel.cu"); println!("cargo:rerun-if-changed=src/gptq_cuda_kernel.cu"); - + let build_dir = PathBuf::from(std::env::var("OUT_DIR").unwrap_or("".to_string())); let builder = bindgen_cuda::Builder::default().arg("--expt-relaxed-constexpr"); println!("cargo:info={builder:?}"); - builder.build_lib("libpagedattention.a"); + builder.build_lib(build_dir.join("libpagedattention.a")); let bindings = builder.build_ptx().unwrap(); bindings.write("src/lib.rs").unwrap(); @@ -36,6 +36,7 @@ fn main() -> Result<()> { "cargo:rustc-link-search=native={}", absolute_kernel_dir.display() ); + println!("cargo:rustc-link-search={}", build_dir.display()); println!("cargo:rustc-link-lib=pagedattention"); println!("cargo:rustc-link-lib=dylib=cudart"); From 91890fe1ce3a0264e63d92581f17487d859c2638 Mon Sep 17 00:00:00 2001 From: Guoqing Bao Date: Fri, 10 Jan 2025 03:39:33 +0000 Subject: [PATCH 2/6] Initial support Multi-GPU infernce (with nccl) --- Cargo.lock | 15 +-- Cargo.toml | 4 +- src/lib.rs | 36 ++++-- src/main.rs | 129 ++++++++++++++----- src/openai/distributed.rs | 130 ++++++++++++++++++++ src/openai/mod.rs | 2 + src/openai/models/gemma.rs | 6 +- src/openai/models/llama.rs | 17 ++- src/openai/models/mistral.rs | 6 +- src/openai/models/phi2.rs | 6 +- src/openai/models/phi3.rs | 6 +- src/openai/models/quantized_llama.rs | 4 +- src/openai/models/quantized_phi3.rs | 4 +- src/openai/models/qwen2.rs | 6 +- src/openai/models/stable_lm.rs | 6 +- src/openai/models/yi.rs | 6 +- src/openai/openai_server.rs | 8 +- src/openai/pipelines/llm_engine.rs | 170 ++++++++++++++++---------- src/openai/pipelines/mod.rs | 15 ++- src/openai/pipelines/pipeline.rs | 20 +-- src/paged_attention/input_metadata.rs | 7 +- src/paged_attention/mod.rs | 2 +- src/scheduler/cache_engine.rs | 10 +- 23 files changed, 444 insertions(+), 171 deletions(-) create mode 100644 src/openai/distributed.rs diff --git a/Cargo.lock b/Cargo.lock index 3aff7a2..a20ed3a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -341,7 +341,7 @@ dependencies = [ "byteorder", "candle-kernels", "candle-metal-kernels", - "cudarc 0.12.2", + "cudarc", "gemm", "half", "intel-mkl-src", @@ -485,7 +485,7 @@ dependencies = [ "candle-nn", "candle-transformers", "clap", - "cudarc 0.9.15", + "cudarc", "derive_more", "dirs", "dyn-fmt", @@ -711,15 +711,6 @@ dependencies = [ "memchr", ] -[[package]] -name = "cudarc" -version = "0.9.15" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f1871a911a2b9a3f66a285896a719159985683bf9903aa2cf89e0c9f53e14552" -dependencies = [ - "half", -] - [[package]] name = "cudarc" version = "0.12.2" @@ -3321,7 +3312,7 @@ version = "0.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = 
"1c4dcab280ad0ef3957e153a82dcad608c954d02cf253b695322f502d1f8902e" dependencies = [ - "cudarc 0.12.2", + "cudarc", "half", "serde", "serde_json", diff --git a/Cargo.toml b/Cargo.toml index 4eee274..1beaa24 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -31,7 +31,7 @@ serde_json = "1.0.108" derive_more = "0.99.17" accelerate-src = { version = "0.3.2", optional = true } intel-mkl-src = { version = "0.8.1", features = ["mkl-static-lp64-iomp"], optional = true } -cudarc = { version = "0.9.14", features = ["f16"], optional = true } +cudarc = {version = "0.12.1", features = ["f16"], optional = true } half = { version = "2.3.1", features = ["num-traits", "use-intrinsics", "rand_distr"] } candle-flash-attn = { git = "https://github.com/huggingface/candle.git", version = "0.8.1", optional = true } clap = { version = "4.4.7", features = ["derive"] } @@ -48,7 +48,7 @@ kernels = {path = "./kernels", version="0.1.0", optional = true} metal-kernels = {path = "./metal-kernels", version="0.1.0", optional = true} [features] -#default = ["metal"] +default = ["nccl"] accelerate = ["dep:accelerate-src", "candle-core/accelerate", "candle-nn/accelerate", "candle-transformers/accelerate"] cuda = ["candle-core/cuda", "candle-nn/cuda", "candle-transformers/cuda", "dep:kernels"] metal = ["candle-core/metal", "candle-nn/metal", "candle-transformers/metal", "dep:metal-kernels", "dep:metal"] diff --git a/src/lib.rs b/src/lib.rs index 82a2998..e76346f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,12 +1,17 @@ #![warn(clippy::cast_lossless)] -use std::fmt::Display; - -use candle::Result; +use candle::utils::{cuda_is_available, metal_is_available}; +use candle::{Device, Result}; use candle_core as candle; use clap::Subcommand; use openai::pipelines::{pipeline::DefaultLoader, ModelLoader}; +use std::fmt::Display; use std::path::Path; +pub mod backend; +pub mod openai; +pub mod paged_attention; +pub mod scheduler; + #[derive(Debug, Subcommand)] pub enum ModelSelected { /// Select the llama model (default llama2-7b). @@ -527,7 +532,24 @@ pub fn hub_load_local_safetensors( Ok(safetensors_files) } -pub mod backend; -pub mod openai; -pub mod paged_attention; -pub mod scheduler; +pub fn new_device(cpu: bool, ordinal: usize) -> Result { + if cpu { + Ok(Device::Cpu) + } else if cuda_is_available() { + Ok(Device::new_cuda(ordinal)?) + } else if metal_is_available() { + Ok(Device::new_metal(ordinal)?) 
+ } else { + #[cfg(all(target_os = "macos", target_arch = "aarch64"))] + { + println!( + "Running on CPU, to run on GPU(metal), build this example with `--features metal`" + ); + } + #[cfg(not(all(target_os = "macos", target_arch = "aarch64")))] + { + println!("Running on CPU, to run on GPU, build this example with `--features cuda`"); + } + Ok(Device::Cpu) + } +} diff --git a/src/main.rs b/src/main.rs index 67795db..41020b1 100644 --- a/src/main.rs +++ b/src/main.rs @@ -4,11 +4,11 @@ use axum::{ Router, }; use candle_core::{DType, Device}; -use candle_vllm::openai::openai_server::chat_completions; use candle_vllm::openai::pipelines::llm_engine::LLMEngine; use candle_vllm::openai::pipelines::pipeline::DefaultModelPaths; use candle_vllm::openai::responses::APIError; use candle_vllm::openai::OpenAIServerData; +use candle_vllm::openai::{openai_server::chat_completions, PipelineConfig}; use candle_vllm::scheduler::cache_engine::CacheConfig; use candle_vllm::scheduler::SchedulerConfig; use candle_vllm::{get_model_loader, hub_load_local_safetensors, ModelSelected}; @@ -80,6 +80,39 @@ struct Args { /// Record conversation (default false, the client need to record chat history) #[arg(long)] record_conversation: bool, + + #[arg(long, value_delimiter = ',')] + device_ids: Option>, +} + +fn get_cache_config( + kvcache_mem_gpu: usize, + kvcache_mem_cpu: usize, + block_size: usize, + config: &Config, +) -> CacheConfig { + let dsize = config.kv_cache_dtype.size_in_bytes(); + let num_gpu_blocks = kvcache_mem_gpu * SIZE_IN_MB + / dsize + / block_size + / config.num_key_value_heads + / config.get_head_size() + / config.num_hidden_layers + / 2; + let num_cpu_blocks = kvcache_mem_cpu * SIZE_IN_MB + / dsize + / block_size + / config.num_key_value_heads + / config.get_head_size() + / config.num_hidden_layers + / 2; + CacheConfig { + block_size: block_size, + num_gpu_blocks: Some(num_gpu_blocks), + num_cpu_blocks: Some(num_cpu_blocks), + fully_init: true, + dtype: config.kv_cache_dtype, + } } #[tokio::main] @@ -156,45 +189,85 @@ async fn main() -> Result<(), APIError> { None => DType::BF16, }; - let device = candle_examples::device(args.cpu).unwrap(); - let model = loader.load_model(paths, dtype, quant, device)?; - let config: Config = model.0.get_model_config(); - let dsize = config.kv_cache_dtype.size_in_bytes(); - let num_gpu_blocks = args.kvcache_mem_gpu * SIZE_IN_MB - / dsize - / args.block_size - / config.num_key_value_heads - / config.get_head_size() - / config.num_hidden_layers - / 2; - let num_cpu_blocks = args.kvcache_mem_cpu * SIZE_IN_MB - / dsize - / args.block_size - / config.num_key_value_heads - / config.get_head_size() - / config.num_hidden_layers - / 2; - let cache_config = CacheConfig { - block_size: args.block_size, - num_gpu_blocks: Some(num_gpu_blocks), - num_cpu_blocks: Some(num_cpu_blocks), - fully_init: true, - dtype: config.kv_cache_dtype, + let device_ids: Vec = match args.device_ids { + Some(ids) => ids, + _ => vec![0usize], }; + use candle_vllm::openai::pipelines::ModulePipeline; + use candle_vllm::scheduler::cache_engine::CacheEngine; + use std::collections::HashMap; + let mut pipelines = HashMap::, CacheEngine)>::new(); + use std::rc::Rc; + let mut cache_config: Option = None; + let mut config: Option = None; + let mut pipeline_config: Option = None; + + let num_shards = device_ids.len(); + #[cfg(feature = "nccl")] + use cudarc::nccl::safe::{Comm, Id}; + #[cfg(feature = "nccl")] + let id = Id::new().unwrap(); + for (rank, did) in device_ids.iter().enumerate() { + let device = 
candle_vllm::new_device(args.cpu, *did).unwrap(); + // let device = device.as_cuda_device().unwrap(); + #[cfg(feature = "nccl")] + let comm = match Comm::from_rank( + device.as_cuda_device().unwrap().cuda_device(), + rank, + num_shards, + id, + ) { + Ok(comm) => Rc::new(comm), + Err(err) => panic!("nccl error {:?}", err.0), + }; + println!("Loading model on device rank {}", rank); + let model = loader.load_model( + &paths, + dtype, + &quant, + device.clone(), + #[cfg(feature = "nccl")] + Some(comm), + )?; + if config.is_none() { + config = Some(model.0.get_model_config()); + } + if cache_config.is_none() { + cache_config = Some(get_cache_config( + args.kvcache_mem_gpu, + args.kvcache_mem_cpu, + args.block_size, + &config.as_ref().expect("invalid config!"), + )); + } + if pipeline_config.is_none() { + pipeline_config = Some(model.1); + } + let cache_engine = CacheEngine::new( + config.as_ref().expect("invalid config!"), + cache_config.as_ref().expect("invalid cache config!"), + cache_config.as_ref().expect("invalid cache config!").dtype, + &device, + )?; + pipelines.insert(rank, (model.0, cache_engine)); + } + let cache_config = cache_config.as_ref().unwrap().clone(); + let config = config.as_ref().unwrap().clone(); println!("Cache config {:?}", cache_config); let finish_notify = Arc::new(Notify::new()); let llm_engine = LLMEngine::new( - model.0, + pipelines, SchedulerConfig { max_num_seqs: args.max_num_seqs, }, - cache_config, + &cache_config, + &config, Arc::new(Notify::new()), finish_notify.clone(), )?; let server_data = OpenAIServerData { - pipeline_config: model.1, + pipeline_config: pipeline_config.unwrap(), model: llm_engine, record_conversation: args.record_conversation, device: Device::Cpu, diff --git a/src/openai/distributed.rs b/src/openai/distributed.rs new file mode 100644 index 0000000..8e0126d --- /dev/null +++ b/src/openai/distributed.rs @@ -0,0 +1,130 @@ +use candle_core::backend::BackendStorage; +use candle_core::CustomOp1; +use candle_core::{CpuStorage, DType, Layout, Module, Result, Shape, Tensor}; +use candle_nn::var_builder::ShardedVarBuilder as VarBuilder; +use candle_nn::Linear; +pub use cudarc::nccl::safe::{Comm, ReduceOp}; +pub use std::rc::Rc; +pub struct TensorParallelColumnLinear { + linear: Linear, +} + +impl TensorParallelColumnLinear { + pub fn new(linear: Linear) -> Self { + Self { linear } + } + pub fn forward(&self, x: &Tensor) -> Result { + self.linear.forward(x) + } +} + +pub struct TensorParallelRowLinear { + linear: Linear, + all_reduce: AllReduce, +} + +struct AllReduce { + comm: Rc, +} + +unsafe impl Sync for AllReduce {} +unsafe impl Send for AllReduce {} + +impl CustomOp1 for AllReduce { + fn name(&self) -> &'static str { + "allreduce" + } + + fn cpu_fwd(&self, _s: &CpuStorage, _l: &Layout) -> Result<(CpuStorage, Shape)> { + candle_core::bail!("AllReduce is never used on cpu") + } + + #[cfg(feature = "cuda")] + fn cuda_fwd( + &self, + s: &candle_core::CudaStorage, + l: &Layout, + ) -> Result<(candle_core::CudaStorage, Shape)> { + use candle_core::cuda_backend::cudarc::driver::DeviceSlice; + use candle_core::cuda_backend::WrapErr; + use half::{bf16, f16}; + + let elem_count = l.shape().elem_count(); + let dev = s.device().clone(); + let dst = match s.dtype() { + DType::BF16 => { + let s = s.as_cuda_slice::()?; + let s = match l.contiguous_offsets() { + Some((0, l)) if l == s.len() => s, + Some(_) | None => candle_core::bail!("input has to be contiguous"), + }; + let mut dst = unsafe { dev.alloc::(elem_count) }.w()?; + self.comm + .all_reduce(s, 
&mut dst, &ReduceOp::Sum) + .map_err(candle_core::Error::debug)?; + candle_core::CudaStorage::wrap_cuda_slice(dst, dev) + } + DType::F16 => { + let s = s.as_cuda_slice::()?; + let s = match l.contiguous_offsets() { + Some((0, l)) if l == s.len() => s, + Some(_) | None => candle_core::bail!("input has to be contiguous"), + }; + let mut dst = unsafe { dev.alloc::(elem_count) }.w()?; + self.comm + .all_reduce(s, &mut dst, &ReduceOp::Sum) + .map_err(candle_core::Error::debug)?; + candle_core::CudaStorage::wrap_cuda_slice(dst, dev) + } + dtype => candle_core::bail!("unsupported dtype {dtype:?}"), + }; + Ok((dst, l.shape().clone())) + } +} + +impl TensorParallelRowLinear { + pub fn new(linear: Linear, comm: Rc) -> Self { + let all_reduce = AllReduce { comm }; + Self { linear, all_reduce } + } + pub fn forward(&self, x: &Tensor) -> Result { + self.linear.forward(x)?.apply_op1_no_bwd(&self.all_reduce) + } +} + +pub fn shard(dim: usize, rank: usize, world_size: usize) -> candle_nn::var_builder::Shard { + candle_nn::var_builder::Shard { + dim, + rank, + world_size, + } +} + +impl TensorParallelColumnLinear { + pub fn load(vb: VarBuilder, comm: Rc) -> Result { + let rank = comm.rank(); + let size = comm.world_size(); + let weight = vb.get_with_hints((), "weight", shard(0, rank, size))?; + Ok(Self::new(Linear::new(weight, None))) + } + + pub fn load_multi(vb: VarBuilder, prefixes: &[&str], comm: Rc) -> Result { + let rank = comm.rank(); + let size = comm.world_size(); + let weights: Vec<_> = prefixes + .iter() + .map(|p| vb.pp(p).get_with_hints((), "weight", shard(0, rank, size))) + .collect::>>()?; + let weight = Tensor::cat(&weights, 0)?.contiguous()?; + Ok(Self::new(Linear::new(weight, None))) + } +} + +impl TensorParallelRowLinear { + pub fn load(vb: VarBuilder, comm: Rc) -> Result { + let rank = comm.rank(); + let size = comm.world_size(); + let weight = vb.get_with_hints((), "weight", shard(1, rank, size))?; + Ok(Self::new(Linear::new(weight, None), comm)) + } +} diff --git a/src/openai/mod.rs b/src/openai/mod.rs index 35a1f93..b8d7053 100644 --- a/src/openai/mod.rs +++ b/src/openai/mod.rs @@ -5,6 +5,8 @@ use tokio::sync::{Mutex, Notify}; use self::{pipelines::llm_engine::LLMEngine, responses::APIError}; +#[cfg(feature = "nccl")] +pub mod distributed; pub mod requests; pub mod responses; pub mod sampling_params; diff --git a/src/openai/models/gemma.rs b/src/openai/models/gemma.rs index e0a83af..ae680cd 100644 --- a/src/openai/models/gemma.rs +++ b/src/openai/models/gemma.rs @@ -285,7 +285,7 @@ impl Attention { attention_mask: Option<&Tensor>, input_positions: &[Vec], cache: Option<(&Tensor, &Tensor)>, - input_metadata: &mut InputMetadata, + input_metadata: &InputMetadata, softcapping: Option, ) -> Result { let (b_sz, seq_len, _) = xs.dims3()?; @@ -411,7 +411,7 @@ impl DecoderLayer { attention_mask: Option<&Tensor>, input_positions: &[Vec], cache: Option<(&Tensor, &Tensor)>, - input_metadata: &mut InputMetadata, + input_metadata: &InputMetadata, softcapping: Option, ) -> Result { let residual = xs; @@ -503,7 +503,7 @@ impl Gemma { input_ids: &Tensor, input_positions: &[Vec], kv_caches: Option<&Vec<(Tensor, Tensor)>>, - input_metadata: &mut InputMetadata, + input_metadata: &InputMetadata, ) -> Result { let (b_size, seq_len) = input_ids.dims2()?; let attention_mask = if seq_len <= 1 { diff --git a/src/openai/models/llama.rs b/src/openai/models/llama.rs index e00674f..71156f0 100644 --- a/src/openai/models/llama.rs +++ b/src/openai/models/llama.rs @@ -9,7 +9,10 @@ use candle_nn::{embedding, 
Embedding, Module, VarBuilder}; use candle_transformers::models::with_tracing::RmsNorm; pub const MAX_SEQ_LEN: usize = 4096; use crate::openai::models::TokenID; +#[cfg(feature = "nccl")] +pub use cudarc::nccl::safe::Comm; use std::iter::zip; +pub use std::rc::Rc; #[derive(Debug, Clone, serde::Deserialize)] pub struct LlamaConfig { @@ -138,7 +141,7 @@ impl CausalSelfAttention { attention_mask: Option<&Tensor>, input_positions: &[Vec], cache: Option<(&Tensor, &Tensor)>, - input_metadata: &mut InputMetadata, + input_metadata: &InputMetadata, ) -> Result { let _enter = self.span.enter(); let (b_sz, seq_len, hidden_size) = x.dims3()?; @@ -321,7 +324,7 @@ impl Block { attention_mask: Option<&Tensor>, input_positions: &[Vec], cache: Option<(&Tensor, &Tensor)>, - input_metadata: &mut InputMetadata, + input_metadata: &InputMetadata, ) -> Result { let _enter = self.span.enter(); let residual = x; @@ -381,7 +384,7 @@ impl Llama { x: &Tensor, input_positions: &[Vec], kv_caches: Option<&Vec<(Tensor, Tensor)>>, - input_metadata: &mut InputMetadata, + input_metadata: &InputMetadata, ) -> Result { let (_b_sz, seq_len) = x.dims2()?; let attention_mask = if seq_len <= 1 { @@ -418,7 +421,13 @@ impl Llama { logits.to_dtype(DType::F32) } - pub fn load(vb: VarBuilder, cfg: &Config, dtype: DType, device: &Device) -> Result { + pub fn load( + vb: VarBuilder, + cfg: &Config, + dtype: DType, + device: &Device, + comm: Option>, + ) -> Result { let wte = embedding(cfg.vocab_size, cfg.hidden_size, vb.pp("model.embed_tokens"))?; let lm_head = linear( cfg.hidden_size, diff --git a/src/openai/models/mistral.rs b/src/openai/models/mistral.rs index 21f9d07..f0af311 100644 --- a/src/openai/models/mistral.rs +++ b/src/openai/models/mistral.rs @@ -261,7 +261,7 @@ impl Attention { attention_mask: Option<&Tensor>, input_positions: &[Vec], cache: Option<(&Tensor, &Tensor)>, - input_metadata: &mut InputMetadata, + input_metadata: &InputMetadata, ) -> Result { let (b_sz, seq_len, _) = xs.dims3()?; @@ -356,7 +356,7 @@ impl DecoderLayer { attention_mask: Option<&Tensor>, input_positions: &[Vec], cache: Option<(&Tensor, &Tensor)>, - input_metadata: &mut InputMetadata, + input_metadata: &InputMetadata, ) -> Result { let residual = xs; let xs = self.input_layernorm.forward(xs)?; @@ -437,7 +437,7 @@ impl Mistral { input_ids: &Tensor, input_positions: &[Vec], kv_caches: Option<&Vec<(Tensor, Tensor)>>, - input_metadata: &mut InputMetadata, + input_metadata: &InputMetadata, ) -> Result { let (b_size, seq_len) = input_ids.dims2()?; let attention_mask = if seq_len <= 1 { diff --git a/src/openai/models/phi2.rs b/src/openai/models/phi2.rs index 1cbfd4e..4d6679b 100644 --- a/src/openai/models/phi2.rs +++ b/src/openai/models/phi2.rs @@ -250,7 +250,7 @@ impl Attention { attention_mask: Option<&Tensor>, input_positions: &[Vec], cache: Option<(&Tensor, &Tensor)>, - input_metadata: &mut InputMetadata, + input_metadata: &InputMetadata, ) -> Result { let (b_size, seq_len, _n_embd) = xs.dims3()?; let query_states = self.q_proj.forward(xs)?; @@ -341,7 +341,7 @@ impl DecoderLayer { mask: Option<&Tensor>, input_positions: &[Vec], cache: Option<(&Tensor, &Tensor)>, - input_metadata: &mut InputMetadata, + input_metadata: &InputMetadata, ) -> Result { let _enter = self.span.enter(); let residual = xs; @@ -411,7 +411,7 @@ impl Phi2 { xs: &Tensor, input_positions: &[Vec], kv_caches: Option<&Vec<(Tensor, Tensor)>>, - input_metadata: &mut InputMetadata, + input_metadata: &InputMetadata, ) -> Result { let (b_size, seq_len) = xs.dims2()?; let mut xs = 
xs.apply(&self.embed_tokens)?; diff --git a/src/openai/models/phi3.rs b/src/openai/models/phi3.rs index da14ec9..08f5dbe 100644 --- a/src/openai/models/phi3.rs +++ b/src/openai/models/phi3.rs @@ -294,7 +294,7 @@ impl Attention { attention_mask: Option<&Tensor>, input_positions: &[Vec], cache: Option<(&Tensor, &Tensor)>, - input_metadata: &mut InputMetadata, + input_metadata: &InputMetadata, ) -> Result { let (b_sz, seq_len, _) = xs.dims3()?; @@ -444,7 +444,7 @@ impl DecoderLayer { attention_mask: Option<&Tensor>, input_positions: &[Vec], cache: Option<(&Tensor, &Tensor)>, - input_metadata: &mut InputMetadata, + input_metadata: &InputMetadata, ) -> Result { let residual = xs; let xs = self.input_layernorm.forward(xs)?; @@ -514,7 +514,7 @@ impl Phi { input_ids: &Tensor, input_positions: &[Vec], kv_caches: Option<&Vec<(Tensor, Tensor)>>, - input_metadata: &mut InputMetadata, + input_metadata: &InputMetadata, ) -> Result { let (b_size, seq_len) = input_ids.dims2()?; let attention_mask = if seq_len <= 1 { diff --git a/src/openai/models/quantized_llama.rs b/src/openai/models/quantized_llama.rs index 5ba0219..066f6a8 100644 --- a/src/openai/models/quantized_llama.rs +++ b/src/openai/models/quantized_llama.rs @@ -173,7 +173,7 @@ impl LayerWeights { mask: Option<&Tensor>, input_positions: &[Vec], cache: Option<(&Tensor, &Tensor)>, - input_metadata: &mut InputMetadata, + input_metadata: &InputMetadata, ) -> Result { let (b_sz, seq_len, n_embd) = x.dims3()?; let q = self.attention_wq.forward(x)?.to_dtype(self.dtype)?; @@ -524,7 +524,7 @@ impl GGUFLLaMa { x: &Tensor, input_positions: &[Vec], kv_caches: Option<&Vec<(Tensor, Tensor)>>, - input_metadata: &mut InputMetadata, + input_metadata: &InputMetadata, ) -> Result { let (b_sz, seq_len) = x.dims2()?; let mask = if seq_len == 1 { diff --git a/src/openai/models/quantized_phi3.rs b/src/openai/models/quantized_phi3.rs index bf0ae78..5e13b97 100644 --- a/src/openai/models/quantized_phi3.rs +++ b/src/openai/models/quantized_phi3.rs @@ -102,7 +102,7 @@ impl LayerWeights { mask: Option<&Tensor>, input_positions: &[Vec], cache: Option<(&Tensor, &Tensor)>, - input_metadata: &mut InputMetadata, + input_metadata: &InputMetadata, ) -> Result { let (b_sz, seq_len, n_embd) = x.dims3()?; let qkv = self.attn_qkv.forward(x)?; @@ -346,7 +346,7 @@ impl GGUFPhi3 { xs: &Tensor, input_positions: &[Vec], kv_caches: Option<&Vec<(Tensor, Tensor)>>, - input_metadata: &mut InputMetadata, + input_metadata: &InputMetadata, ) -> Result { let (b_sz, seq_len) = xs.dims2()?; let mask = if seq_len == 1 { diff --git a/src/openai/models/qwen2.rs b/src/openai/models/qwen2.rs index e0a945d..43b70fa 100644 --- a/src/openai/models/qwen2.rs +++ b/src/openai/models/qwen2.rs @@ -264,7 +264,7 @@ impl Attention { attention_mask: Option<&Tensor>, input_positions: &[Vec], cache: Option<(&Tensor, &Tensor)>, - input_metadata: &mut InputMetadata, + input_metadata: &InputMetadata, ) -> Result { let (b_sz, seq_len, _) = xs.dims3()?; @@ -358,7 +358,7 @@ impl DecoderLayer { attention_mask: Option<&Tensor>, input_positions: &[Vec], cache: Option<(&Tensor, &Tensor)>, - input_metadata: &mut InputMetadata, + input_metadata: &InputMetadata, ) -> Result { let residual = xs; let xs = self.input_layernorm.forward(xs)?; @@ -451,7 +451,7 @@ impl Qwen2 { input_ids: &Tensor, input_positions: &[Vec], kv_caches: Option<&Vec<(Tensor, Tensor)>>, - input_metadata: &mut InputMetadata, + input_metadata: &InputMetadata, ) -> Result { let (b_size, seq_len) = input_ids.dims2()?; let attention_mask = if seq_len <= 1 { diff 
--git a/src/openai/models/stable_lm.rs b/src/openai/models/stable_lm.rs index 9e5a815..f07688d 100644 --- a/src/openai/models/stable_lm.rs +++ b/src/openai/models/stable_lm.rs @@ -273,7 +273,7 @@ impl Attention { attention_mask: Option<&Tensor>, input_positions: &[Vec], cache: Option<(&Tensor, &Tensor)>, - input_metadata: &mut InputMetadata, + input_metadata: &InputMetadata, ) -> Result { let (b_sz, seq_len, _) = xs.dims3()?; @@ -370,7 +370,7 @@ impl DecoderLayer { attention_mask: Option<&Tensor>, input_positions: &[Vec], cache: Option<(&Tensor, &Tensor)>, - input_metadata: &mut InputMetadata, + input_metadata: &InputMetadata, ) -> Result { let residual = xs; let xs = self.input_layernorm.forward(xs)?; @@ -441,7 +441,7 @@ impl StableLM { input_ids: &Tensor, input_positions: &[Vec], kv_caches: Option<&Vec<(Tensor, Tensor)>>, - input_metadata: &mut InputMetadata, + input_metadata: &InputMetadata, ) -> Result { let (b_size, seq_len) = input_ids.dims2()?; let attention_mask = if seq_len <= 1 { diff --git a/src/openai/models/yi.rs b/src/openai/models/yi.rs index 6f0c6d3..92173bb 100644 --- a/src/openai/models/yi.rs +++ b/src/openai/models/yi.rs @@ -260,7 +260,7 @@ impl Attention { attention_mask: Option<&Tensor>, input_positions: &[Vec], cache: Option<(&Tensor, &Tensor)>, - input_metadata: &mut InputMetadata, + input_metadata: &InputMetadata, ) -> Result { let (b_sz, seq_len, _) = xs.dims3()?; @@ -354,7 +354,7 @@ impl DecoderLayer { attention_mask: Option<&Tensor>, input_positions: &[Vec], cache: Option<(&Tensor, &Tensor)>, - input_metadata: &mut InputMetadata, + input_metadata: &InputMetadata, ) -> Result { let residual = xs; let xs = self.ln1.forward(xs)?; @@ -425,7 +425,7 @@ impl Yi { input_ids: &Tensor, input_positions: &[Vec], kv_caches: Option<&Vec<(Tensor, Tensor)>>, - input_metadata: &mut InputMetadata, + input_metadata: &InputMetadata, ) -> Result { let (b_size, seq_len) = input_ids.dims2()?; let attention_mask = if seq_len <= 1 { diff --git a/src/openai/openai_server.rs b/src/openai/openai_server.rs index eb597a4..2ef415e 100644 --- a/src/openai/openai_server.rs +++ b/src/openai/openai_server.rs @@ -37,7 +37,9 @@ async fn get_gen_prompt( ) -> Result { let mut model = data.model.lock().await; let conversation = model - .get_mut_pipeline() + .get_mut_pipeline(0) + .unwrap() + .0 .get_conversation(data.record_conversation); match &request.messages { @@ -76,7 +78,9 @@ async fn check_length( let token_ids = { let model = data.model.lock().await; model - .get_pipeline() + .get_pipeline(0) + .unwrap() + .0 .tokenizer() .tokenizer() .encode(prompt, false) diff --git a/src/openai/pipelines/llm_engine.rs b/src/openai/pipelines/llm_engine.rs index 1e57f04..f80c346 100644 --- a/src/openai/pipelines/llm_engine.rs +++ b/src/openai/pipelines/llm_engine.rs @@ -9,6 +9,7 @@ use crate::openai::streaming::ChatResponse; use crate::scheduler::Scheduler; use crate::{ openai::{ + models::Config, responses::{ APIError, ChatChoice, ChatChoiceData, ChatCompletionChunk, ChatCompletionUsageResponse, Choice, ChoiceData, WrapperLogprobs, @@ -41,13 +42,12 @@ struct PreparedInputs { const _PAD_SLOT_ID: i64 = -1; pub struct LLMEngine { - pipeline: Box, + pipelines: HashMap, CacheEngine)>, scheduler: Scheduler, seq_id: usize, cache_config: CacheConfig, + config: Config, group_id: usize, - cache_engine: CacheEngine, - sliding_window: Option, pub notify: Arc, pub finish_notify: Arc, pub completion_records: HashMap, ChatCompletionUsageResponse)>, @@ -55,28 +55,20 @@ pub struct LLMEngine { impl LLMEngine { pub fn new( - 
pipeline: Box, + pipelines: HashMap, CacheEngine)>, scheduler_config: SchedulerConfig, - cache_config: CacheConfig, + cache_config: &CacheConfig, + config: &Config, notify: Arc, finish_notify: Arc, ) -> Result>, APIError> { - let cache_engine = CacheEngine::new( - pipeline.get_model_config(), - cache_config.clone(), - cache_config.dtype, - pipeline.device(), - )?; - let sliding_window = pipeline.get_model_config().sliding_window; - let engine = Arc::new(Mutex::new(Self { - pipeline, - scheduler: Scheduler::new(scheduler_config, &cache_config), + pipelines, + scheduler: Scheduler::new(scheduler_config, cache_config), seq_id: 0, - cache_config, + cache_config: cache_config.clone(), + config: config.clone(), group_id: 0, - cache_engine, - sliding_window, notify: notify.clone(), finish_notify: finish_notify.clone(), completion_records: HashMap::new(), @@ -145,12 +137,15 @@ impl LLMEngine { Ok(engine_clone) } - pub fn get_pipeline(&self) -> &dyn ModulePipeline { - &*self.pipeline + pub fn get_pipeline(&self, rank: usize) -> Option<&(Box, CacheEngine)> { + self.pipelines.get(&rank) } - pub fn get_mut_pipeline(&mut self) -> &mut dyn ModulePipeline { - &mut *self.pipeline + pub fn get_mut_pipeline( + &mut self, + rank: usize, + ) -> Option<&mut (Box, CacheEngine)> { + self.pipelines.get_mut(&rank) } fn get_stream_response( @@ -161,9 +156,10 @@ impl LLMEngine { finish_reason: Option, ) -> ChatCompletionChunk { let mut choices = Vec::new(); + let pipline = self.get_mut_pipeline(0).unwrap().0.as_mut(); let choice = Choice { delta: ChoiceData { - role: self.pipeline.get_conversation(true).get_roles().0.clone(), + role: pipline.get_conversation(true).get_roles().0.clone(), content, }, finish_reason, @@ -175,7 +171,7 @@ impl LLMEngine { id: request_id, choices, created, - model: self.pipeline.name().to_string(), + model: pipline.name().to_string(), object: "chat.completion.chunk", system_fingerprint: None, } @@ -194,7 +190,7 @@ impl LLMEngine { todo!(); } - self.execute_scheduler_ops(&scheduler_outputs).unwrap(); + self.execute_scheduler_ops(&scheduler_outputs, 0).unwrap(); let scheduled: &VecDeque> = &scheduler_outputs.scheduled; // for group in scheduled.iter() { @@ -205,22 +201,68 @@ impl LLMEngine { positions, metadata, } = if seqs.values().nth(0).unwrap().deref().is_prompt() { - self.prepare_prompt(scheduled) + self.prepare_prompt(scheduled, 0) } else { - self.prepare_decode(scheduled) + self.prepare_decode(scheduled, 0) } .unwrap(); + use rayon::iter::IntoParallelRefMutIterator; + use rayon::iter::ParallelIterator; + let vec_logits: Vec = self + .pipelines + .par_iter_mut() + .map(|(rank, (pipeline, cache_engine))| { + let device = pipeline.device(); + let metadata_ = if *rank == 0 { + &metadata + } else { + let context_lens = if metadata.context_lens.is_some() { + Some( + metadata + .context_lens + .as_ref() + .unwrap() + .to_device(device) + .unwrap(), + ) + } else { + metadata.context_lens.clone() + }; + let block_tables = if metadata.block_tables.is_some() { + Some( + metadata + .block_tables + .as_ref() + .unwrap() + .to_device(device) + .unwrap(), + ) + } else { + metadata.block_tables.clone() + }; - let logits = self - .pipeline - .forward( - tokens, - &positions, - Some(&*self.cache_engine.get_kv_cache()), - metadata, - ) - .unwrap(); - let results = self.pipeline.sample(logits, scheduled).unwrap(); + &InputMetadata { + //for other rank, some tensors need to be moved + slot_mapping: metadata.slot_mapping.to_device(device).unwrap(), + context_lens, + block_tables, + kv_cache_dtype: 
metadata.kv_cache_dtype.clone(), + prompt_lens: metadata.prompt_lens.clone(), + ..metadata + } + }; + pipeline + .forward( + tokens.clone(), + &positions, + Some(&*cache_engine.get_kv_cache()), + metadata_, + ) + .unwrap() + }) + .collect(); + let pipeline = self.get_mut_pipeline(0).unwrap().0.as_mut(); + let results = pipeline.sample(&vec_logits[0], scheduled).unwrap(); for (result_, group) in zip(results, scheduled) { match result_ { @@ -298,15 +340,15 @@ impl LLMEngine { .iter() .map(|x| x.token.try_into().unwrap()) .collect::>(); - let data = self - .pipeline + let pipeline = self.get_mut_pipeline(0usize).unwrap().0.as_mut(); + let data = pipeline .tokenizer() .tokenizer() .decode(&data, false) .unwrap(); let choice = ChatChoice { message: ChatChoiceData { - role: self.pipeline.get_conversation(true).get_roles().0.clone(), + role: pipeline.get_conversation(true).get_roles().0.clone(), content: Some(data), }, finish_reason: Some(seq.deref_mut().get_finish_reason().clone()), @@ -349,7 +391,8 @@ impl LLMEngine { } } } - self.pipeline.reset_decoder(); + let default_pipeline = self.get_mut_pipeline(0usize).unwrap().0.as_mut(); + default_pipeline.reset_decoder(); Ok(responses) } } @@ -358,21 +401,17 @@ impl LLMEngine { fn execute_scheduler_ops( &mut self, scheduler_output: &SchedulerOutput, + rank: usize, ) -> Result<(), APIError> { + let cache_engine = Box::new(&mut self.get_mut_pipeline(rank).unwrap().1); if !scheduler_output.blocks_to_swap_in.is_empty() { - try_api!(self - .cache_engine - .swap_in(scheduler_output.blocks_to_swap_in.clone())); + try_api!(cache_engine.swap_in(scheduler_output.blocks_to_swap_in.clone())); } if !scheduler_output.blocks_to_swap_out.is_empty() { - try_api!(self - .cache_engine - .swap_out(scheduler_output.blocks_to_swap_out.clone())); + try_api!(cache_engine.swap_out(scheduler_output.blocks_to_swap_out.clone())); } if !scheduler_output.blocks_to_copy.is_empty() { - try_api!(self - .cache_engine - .copy(scheduler_output.blocks_to_copy.clone())); + try_api!(cache_engine.copy(scheduler_output.blocks_to_copy.clone())); } Ok(()) } @@ -380,6 +419,7 @@ impl LLMEngine { fn prepare_prompt( &self, groups: &VecDeque>, + rank: usize, ) -> Result { let mut prompt_lens = Vec::new(); let mut input_tokens = Vec::new(); @@ -410,7 +450,7 @@ impl LLMEngine { .map(|block| block.deref_mut().block_id) .collect::>(); - let start_idx = if let Some(sliding_window) = self.sliding_window { + let start_idx = if let Some(sliding_window) = self.config.sliding_window { if prompt_len > sliding_window { 0.min(prompt_len - sliding_window) } else { @@ -444,6 +484,7 @@ impl LLMEngine { slot_mappings.push(slot_mapping); } } + let device = self.get_pipeline(rank).unwrap().0.device(); let max_prompt_len = prompt_lens.iter().max().unwrap(); let input_tokens = _make_tensor_with_pad( @@ -453,14 +494,10 @@ impl LLMEngine { .collect::>(), *max_prompt_len, 0, - self.pipeline.device(), - )?; - let slot_mapping = _make_tensor_with_pad( - slot_mappings, - *max_prompt_len, - _PAD_SLOT_ID, - self.pipeline.device(), + device, )?; + let slot_mapping = + _make_tensor_with_pad(slot_mappings, *max_prompt_len, _PAD_SLOT_ID, device)?; Ok(PreparedInputs { tokens: input_tokens, @@ -471,7 +508,6 @@ impl LLMEngine { max_context_len: None, context_lens: None, block_tables: None, - attn_bias: None, is_prompt: true, kv_cache_dtype: "auto".to_string(), // TODO(EricLBuehler): specialize for models }, @@ -481,6 +517,7 @@ impl LLMEngine { fn prepare_decode( &self, groups: &VecDeque>, + rank: usize, ) -> Result { let mut 
input_tokens = Vec::new(); let mut input_positions = Vec::new(); @@ -495,7 +532,7 @@ impl LLMEngine { let position = seq.deref_mut().get_len() - 1; input_positions.push(vec![position]); - let context_len = if let Some(sliding_window) = self.sliding_window { + let context_len = if let Some(sliding_window) = self.config.sliding_window { seq.deref_mut().get_len().min(sliding_window) } else { seq.deref_mut().get_len() @@ -523,7 +560,7 @@ impl LLMEngine { let slot = slot.try_into().unwrap(); slot_mappings.push(vec![slot]); - if let Some(sliding_window) = self.sliding_window { + if let Some(sliding_window) = self.config.sliding_window { let sliding_window_blocks = sliding_window / self.cache_config.block_size; let slide_idx = if table.len() > sliding_window_blocks { table.len() - sliding_window_blocks @@ -536,6 +573,7 @@ impl LLMEngine { } } } + let device = self.get_pipeline(rank).unwrap().0.device(); let input_tokens = _make_tensor_with_pad( input_tokens @@ -544,16 +582,15 @@ impl LLMEngine { .collect::>(), 1, 0, - self.pipeline.device(), + device, )?; - let slot_mapping = - _make_tensor_with_pad(slot_mappings, 1, _PAD_SLOT_ID, self.pipeline.device())?; + let slot_mapping = _make_tensor_with_pad(slot_mappings, 1, _PAD_SLOT_ID, device)?; let max_context_len = context_lens.iter().max().unwrap(); let context_lens = try_api!(Tensor::from_vec( context_lens.iter().map(|x| *x as u32).collect::>(), (context_lens.len(),), - self.pipeline.device(), + device, )); let max_block_table_len = block_tables.iter().map(|x| x.len()).max().unwrap(); @@ -564,7 +601,7 @@ impl LLMEngine { .collect::>(), max_block_table_len, 0, - self.pipeline.device(), + device, )?; let block_tables = try_api!(block_tables.reshape(((), max_block_table_len))); Ok(PreparedInputs { @@ -576,7 +613,6 @@ impl LLMEngine { max_context_len: Some(*max_context_len), context_lens: Some(context_lens), block_tables: Some(block_tables), - attn_bias: None, is_prompt: false, kv_cache_dtype: "auto".to_string(), // TODO(EricLBuehler): specialize for models }, diff --git a/src/openai/pipelines/mod.rs b/src/openai/pipelines/mod.rs index 9d8b7ee..7793bf3 100644 --- a/src/openai/pipelines/mod.rs +++ b/src/openai/pipelines/mod.rs @@ -14,25 +14,29 @@ pub mod llm_engine; pub mod pipeline; use crate::scheduler::sequence::SequenceGroup; type TokenOrFinishReason = Either; +#[cfg(feature = "nccl")] +pub use cudarc::nccl::safe::Comm; use std::collections::VecDeque; +pub use std::rc::Rc; + pub trait ModulePipeline: Send + Sync { fn forward( &mut self, input_tokens: Tensor, input_positions: &[Vec], kv_cache: Option<&Vec<(Tensor, Tensor)>>, - input_metadata: InputMetadata, + input_metadata: &InputMetadata, ) -> Result; fn sample( &mut self, - logits: Tensor, + logits: &Tensor, groups: &VecDeque>, ) -> Result, APIError>; fn sample_batch( &mut self, - logits: Tensor, + logits: &Tensor, groups: &VecDeque>, ) -> Result, APIError>; @@ -112,9 +116,10 @@ pub trait ModelLoader { fn load_model( &self, - paths: Box, + paths: &Box, dtype: DType, - quant: Option, + quant: &Option, device: Device, + #[cfg(feature = "nccl")] comm: Option>, ) -> Result<(Box, PipelineConfig), APIError>; } diff --git a/src/openai/pipelines/pipeline.rs b/src/openai/pipelines/pipeline.rs index a2673a6..a4ed7f1 100644 --- a/src/openai/pipelines/pipeline.rs +++ b/src/openai/pipelines/pipeline.rs @@ -34,11 +34,14 @@ use candle_core::quantized::gguf_file; use candle_core::{DType, Device, IndexOp, Tensor}; use candle_examples::token_output_stream::TokenOutputStream; use candle_nn::VarBuilder; 
+#[cfg(feature = "nccl")] +pub use cudarc::nccl::safe::Comm; use either::Either; use either::Either::{Left, Right}; use hf_hub::{api::sync::ApiBuilder, Repo, RepoType}; use rayon::prelude::*; use std::collections::VecDeque; +pub use std::rc::Rc; use std::{path::PathBuf, sync::Arc}; use tokenizers::Tokenizer; const EOS_TOKEN: &str = ""; @@ -138,16 +141,17 @@ impl ModelLoader for DefaultLoader { fn load_model( &self, - paths: Box, + paths: &Box, dtype: DType, - quant: Option, + quant: &Option, device: Device, + #[cfg(feature = "nccl")] comm: Option>, ) -> Result<(Box, PipelineConfig), APIError> { let specific_args = self.config.clone(); let mut stop_token_ids = Vec::::new(); let (model, config, tokenizer, sep_style) = if quant.is_some() - && matches!(quant.unwrap().as_str(), "ggml" | "gguf") + && matches!(quant.as_ref().unwrap().as_str(), "ggml" | "gguf") { let path = paths.get_weight_filenames()[0].clone(); println!( @@ -255,11 +259,11 @@ impl ModelLoader for DefaultLoader { let (model, sep_style) = match self.name.as_str() { "llama" => ( - LLMModel::Llama(try_api!(Llama::load(vb, &config, dtype, &device))), + LLMModel::Llama(try_api!(Llama::load(vb, &config, dtype, &device, comm))), SeparatorStyle::Llama, ), "llama3" => ( - LLMModel::Llama(try_api!(Llama::load(vb, &config, dtype, &device))), + LLMModel::Llama(try_api!(Llama::load(vb, &config, dtype, &device, comm))), SeparatorStyle::Llama3, ), "phi2" => ( @@ -406,7 +410,7 @@ impl ModulePipeline for DefaultPipeline { input_tokens: Tensor, input_positions: &[Vec], kv_cache: Option<&Vec<(Tensor, Tensor)>>, - mut input_metadata: InputMetadata, + mut input_metadata: &InputMetadata, ) -> Result { let input_tokens = if input_tokens.shape().dims().len() < 2 { input_tokens @@ -502,7 +506,7 @@ impl ModulePipeline for DefaultPipeline { fn sample( &mut self, - logits: Tensor, + logits: &Tensor, groups: &VecDeque>, ) -> Result, APIError> { use std::collections::HashMap; @@ -604,7 +608,7 @@ impl ModulePipeline for DefaultPipeline { fn sample_batch( &mut self, - logits: Tensor, + logits: &Tensor, groups: &VecDeque>, ) -> Result, APIError> { use std::collections::HashMap; diff --git a/src/paged_attention/input_metadata.rs b/src/paged_attention/input_metadata.rs index 6e09f9d..5daaba4 100644 --- a/src/paged_attention/input_metadata.rs +++ b/src/paged_attention/input_metadata.rs @@ -1,14 +1,11 @@ use candle_core::Tensor; - -use super::attn_bias::AttentionBiasBlockDiagonal; - pub struct InputMetadata { pub prompt_lens: Vec, pub max_context_len: Option, pub block_tables: Option, pub context_lens: Option, pub slot_mapping: Tensor, - pub attn_bias: Option>, + // pub attn_bias: Option, pub is_prompt: bool, pub kv_cache_dtype: String, } @@ -35,7 +32,7 @@ impl InputMetadata { block_tables, context_lens, slot_mapping, - attn_bias: None, + // attn_bias: None, is_prompt, kv_cache_dtype, } diff --git a/src/paged_attention/mod.rs b/src/paged_attention/mod.rs index 40585cf..eff3a1b 100644 --- a/src/paged_attention/mod.rs +++ b/src/paged_attention/mod.rs @@ -66,7 +66,7 @@ impl PagedAttention { attention_mask: Option<&Tensor>, mut key_cache: Option, mut value_cache: Option, - input_metadata: &mut InputMetadata, + input_metadata: &InputMetadata, softcapping: Option, ) -> Result { let dims = input_metadata.slot_mapping.dims(); diff --git a/src/scheduler/cache_engine.rs b/src/scheduler/cache_engine.rs index 5ef85b4..7b7b717 100644 --- a/src/scheduler/cache_engine.rs +++ b/src/scheduler/cache_engine.rs @@ -45,19 +45,19 @@ pub struct CacheEngine { impl CacheEngine { pub fn 
new( - model_config: Config, - cache_config: CacheConfig, + model_config: &Config, + cache_config: &CacheConfig, dtype: DType, device: &Device, ) -> Result { Ok(Self { gpu_cache: Arc::new(Mutex::new(Self::allocate_gpu_cache( - &model_config, - &cache_config, + model_config, + cache_config, dtype, device, )?)), - cpu_cache: Self::allocate_cpu_cache(&model_config, &cache_config, dtype, device)?, + cpu_cache: Self::allocate_cpu_cache(model_config, cache_config, dtype, device)?, num_layers: model_config.num_hidden_layers, }) } From 9a446289d3ffe8e8b8c4838db455a6a9bd162e1b Mon Sep 17 00:00:00 2001 From: Guoqing Bao Date: Tue, 14 Jan 2025 08:36:10 +0000 Subject: [PATCH 3/6] Multi-GPU inference, parallel model shards, simplified pipeline (got this to work) --- Cargo.toml | 3 +- README.md | 22 +- kernels/src/ffi.rs | 15 +- kernels/src/gptq_cuda_kernel.cu | 4 +- kernels/src/lib.rs | 2 +- kernels/src/marlin_cuda_kernel.cu | 17 +- kernels/src/pagedattention.cu | 22 +- kernels/src/reshape_and_cache_kernel.cu | 6 +- src/backend/gptq.rs | 13 +- src/backend/paged_attention.rs | 5 +- src/lib.rs | 15 +- src/main.rs | 100 ++-- .../conversation/default_conversation.rs | 2 +- src/openai/models/gemma.rs | 10 +- src/openai/models/llama.rs | 20 +- src/openai/models/llama_multi.rs | 422 +++++++++++++++ src/openai/models/mistral.rs | 10 +- src/openai/models/mod.rs | 2 + src/openai/models/phi2.rs | 10 +- src/openai/models/phi3.rs | 10 +- src/openai/models/quantized_llama.rs | 8 +- src/openai/models/quantized_phi3.rs | 8 +- src/openai/models/qwen2.rs | 10 +- src/openai/models/stable_lm.rs | 10 +- src/openai/models/yi.rs | 10 +- src/openai/pipelines/llm_engine.rs | 144 +++-- src/openai/pipelines/mod.rs | 78 +-- src/openai/pipelines/pipeline.rs | 496 ++++++++++-------- src/paged_attention/mod.rs | 10 +- src/scheduler/cache_engine.rs | 39 +- 30 files changed, 975 insertions(+), 548 deletions(-) create mode 100644 src/openai/models/llama_multi.rs diff --git a/Cargo.toml b/Cargo.toml index 1beaa24..e4b9d6c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -31,7 +31,7 @@ serde_json = "1.0.108" derive_more = "0.99.17" accelerate-src = { version = "0.3.2", optional = true } intel-mkl-src = { version = "0.8.1", features = ["mkl-static-lp64-iomp"], optional = true } -cudarc = {version = "0.12.1", features = ["f16"], optional = true } +cudarc = {version = "0.12.2", features = ["f16"], optional = true } half = { version = "2.3.1", features = ["num-traits", "use-intrinsics", "rand_distr"] } candle-flash-attn = { git = "https://github.com/huggingface/candle.git", version = "0.8.1", optional = true } clap = { version = "4.4.7", features = ["derive"] } @@ -48,7 +48,6 @@ kernels = {path = "./kernels", version="0.1.0", optional = true} metal-kernels = {path = "./metal-kernels", version="0.1.0", optional = true} [features] -default = ["nccl"] accelerate = ["dep:accelerate-src", "candle-core/accelerate", "candle-nn/accelerate", "candle-transformers/accelerate"] cuda = ["candle-core/cuda", "candle-nn/cuda", "candle-transformers/cuda", "dep:kernels"] metal = ["candle-core/metal", "candle-nn/metal", "candle-transformers/metal", "dep:metal-kernels", "dep:metal"] diff --git a/README.md b/README.md index e4b1a8e..4440de4 100644 --- a/README.md +++ b/README.md @@ -15,6 +15,7 @@ Efficient, easy-to-use platform for inference and serving local LLMs including a - `In-situ` quantization (and `In-situ` marlin format conversion) - `GPTQ/Marlin` format quantization (4-bit) - Support `Mac/Metal` devices +- Support `Multi-GPU` inference ## Develop Status @@ 
-43,7 +44,7 @@ https://github.com/user-attachments/assets/66b5b90e-e2ca-4f0b-82d7-99aa9f85568c ## Usage See [this folder](examples/) for some examples. -### Step 1: Run Candle-VLLM service (assume llama2-7b model weights downloaded) +### Step 1: Run Candle-VLLM service ``` curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh @@ -53,6 +54,7 @@ git clone git@github.com:EricLBuehler/candle-vllm.git cd candle-vllm cargo run --release --features cuda -- --port 2000 --weight-path /home/Meta-Llama-3.1-8B-Instruct/ llama3 --temperature 0. --penalty 1.0 ``` +Note: this assumes the Llama-3.1-8B model weights are downloaded to the folder `/home/Meta-Llama-3.1-8B-Instruct/` You may also run specific model using huggingface model-id, e.g., ```shell @@ -69,8 +71,22 @@ cargo run --release --features metal -- --port 2000 --dtype bf16 --weight-path / __Refer to Marlin quantization below for running quantized GPTQ models.__ +Run `Multi-GPU` inference with the NCCL feature: + +```shell +cargo run --release --features cuda,nccl -- --port 2000 --device-ids "0,1" --weight-path /home/Meta-Llama-3.1-8B-Instruct/ llama3 --temperature 0. --penalty 1.0 +``` + +If you encounter problems under Multi-GPU settings, you may: +```shell +export NCCL_P2P_LEVEL=LOC # use local devices (multiple cards within a server, PCIe, etc.) +export NCCL_P2P_DISABLE=1 # disable P2P, because this feature can cause illegal memory access in certain environments +export NCCL_IB_DISABLE=1 # disable InfiniBand (optional) +``` +**Note:** quantized models are not yet supported in the multi-GPU setting. + ### Step 2: -#### Option 1: Chat with Chat.py (recommended) +#### Option 1: Chat with Chat.py (for simple tests) Install API and chatbot dependencies (openai package is only used for local chat with candle-vllm) ```shell @@ -92,7 +108,7 @@ Chat demo on Apple M4 (Phi3 3.8B) -#### Option 2: Chat with ChatUI +#### Option 2: Chat with ChatUI (recommended) Install ChatUI and its dependencies: ``` diff --git a/kernels/src/ffi.rs b/kernels/src/ffi.rs index a8a5a74..c014eb8 100644 --- a/kernels/src/ffi.rs +++ b/kernels/src/ffi.rs @@ -15,8 +15,8 @@ extern "C" { x: c_int, key_stride: c_int, value_stride: c_int, - dtype: u32, + stream: i64, ); pub fn paged_attention_v1( @@ -41,6 +41,7 @@ extern "C" { dtype: u32, softscapping: f32, + stream: i64, ); pub fn paged_attention_v2( @@ -68,6 +69,7 @@ extern "C" { dtype: u32, softscapping: f32, + stream: i64, ); pub fn marlin_4bit_f16( @@ -80,6 +82,7 @@ extern "C" { n: c_int, workspace: *const c_void, //tensor with at least `n / 128 * max_par` entries that are all zero groupsize: c_int, + stream: i64, ); pub fn marlin_4bit_bf16( @@ -92,9 +95,16 @@ extern "C" { n: c_int, workspace: *const c_void, //tensor with at least `n / 128 * max_par` entries that are all zero groupsize: c_int, + stream: i64, ); - pub fn gptq_repack(weight: *const c_void, result: *const c_void, m: c_int, n: c_int); + pub fn gptq_repack( + weight: *const c_void, + result: *const c_void, + m: c_int, + n: c_int, + stream: i64, + ); pub fn gemm_half_q_half_alt( a: *const c_void, @@ -107,5 +117,6 @@ extern "C" { n: i32, k: i32, bit: i32, + stream: i64, ); } diff --git a/kernels/src/gptq_cuda_kernel.cu b/kernels/src/gptq_cuda_kernel.cu index 90a869e..b719c56 100644 --- a/kernels/src/gptq_cuda_kernel.cu +++ b/kernels/src/gptq_cuda_kernel.cu @@ -199,7 +199,7 @@ extern "C" void gemm_half_q_half_alt(const void* a, const uint32_t* b_q_weight, const uint32_t* b_gptq_qzeros, const void* b_gptq_scales, const int* b_g_idx, void* c, int size_m, int size_n, int size_k, -
int bit) { + int bit, int64_t stream_) { dim3 blockDim, gridDim; blockDim.x = BLOCK_KN_SIZE; blockDim.y = 1; @@ -213,7 +213,7 @@ extern "C" void gemm_half_q_half_alt(const void* a, const uint32_t* b_q_weight, kernel = gemm_half_q_half_alt_8bit_kernel; } - const cudaStream_t stream = 0; + const cudaStream_t stream = (cudaStream_t)stream_; kernel<<>>( (const half2*)(const half*)a, b_q_weight, (half*)c, (const half*)b_gptq_scales, b_gptq_qzeros, b_g_idx, size_m, size_k / 32 * bit, size_n); diff --git a/kernels/src/lib.rs b/kernels/src/lib.rs index b38bc92..c2f90d3 100644 --- a/kernels/src/lib.rs +++ b/kernels/src/lib.rs @@ -1,9 +1,9 @@ pub const COPY_BLOCKS_KERNEL: &str = include_str!(concat!(env!("OUT_DIR"), "/copy_blocks_kernel.ptx")); +pub const GPTQ_CUDA_KERNEL: &str = include_str!(concat!(env!("OUT_DIR"), "/gptq_cuda_kernel.ptx")); pub const MARLIN_CUDA_KERNEL: &str = include_str!(concat!(env!("OUT_DIR"), "/marlin_cuda_kernel.ptx")); pub const PAGEDATTENTION: &str = include_str!(concat!(env!("OUT_DIR"), "/pagedattention.ptx")); pub const RESHAPE_AND_CACHE_KERNEL: &str = include_str!(concat!(env!("OUT_DIR"), "/reshape_and_cache_kernel.ptx")); -pub const GPTQ_CUDA_KERNEL: &str = include_str!(concat!(env!("OUT_DIR"), "/gptq_cuda_kernel.ptx")); pub mod ffi; diff --git a/kernels/src/marlin_cuda_kernel.cu b/kernels/src/marlin_cuda_kernel.cu index 0670cd7..58c785a 100644 --- a/kernels/src/marlin_cuda_kernel.cu +++ b/kernels/src/marlin_cuda_kernel.cu @@ -865,11 +865,11 @@ thread_config_t determine_thread_config(int prob_m, int prob_n, int prob_k) { template void marlin_matmul(const void* A, const void* B, void* s, void* C, int prob_m, int prob_k, - int prob_n, void* workspace, int groupsize + int prob_n, void* workspace, int groupsize, int64_t stream_ ) { int dev = 0; - cudaStream_t stream = 0; + cudaStream_t stream = (cudaStream_t)stream_; int thread_k = -1; int thread_n = -1; int sms = -1; @@ -950,15 +950,15 @@ void marlin_matmul(const void* A, const void* B, void* s, void* C, int prob_m, i } extern "C" void marlin_4bit_f16(const void* A, const void* B, void* s, void* C, int prob_m, int prob_k, - int prob_n, void* workspace, int groupsize + int prob_n, void* workspace, int groupsize, int64_t stream ) { - marlin_matmul(A, B, s, C, prob_m, prob_k, prob_n, workspace, groupsize); + marlin_matmul(A, B, s, C, prob_m, prob_k, prob_n, workspace, groupsize, stream); } extern "C" void marlin_4bit_bf16(const void* A, const void* B, void* s, void* C, int prob_m, int prob_k, - int prob_n, void* workspace, int groupsize + int prob_n, void* workspace, int groupsize, int64_t stream ) { - marlin_matmul(A, B, s, C, prob_m, prob_k, prob_n, workspace, groupsize); + marlin_matmul(A, B, s, C, prob_m, prob_k, prob_n, workspace, groupsize, stream); } @@ -1025,7 +1025,8 @@ extern "C" void gptq_repack( void* in, void* out, int m, - int n + int n, + int64_t stream_ ) { assert(m % 2 == 0); @@ -1033,7 +1034,7 @@ extern "C" void gptq_repack( const dim3 threads(32); // marlin packs 16 x 64 block and gptq packs 8 x 1 const dim3 blocks(m / 2, n / 64); - cudaStream_t stream = 0; + cudaStream_t stream = (cudaStream_t)stream_; gptq_repack_kernel<<>>( (uint32_t*)in, (uint32_t*)out, diff --git a/kernels/src/pagedattention.cu b/kernels/src/pagedattention.cu index 6127677..c3c6710 100644 --- a/kernels/src/pagedattention.cu +++ b/kernels/src/pagedattention.cu @@ -611,7 +611,8 @@ void paged_attention_v1_launcher( int q_stride, int kv_block_stride, int kv_head_stride, - float softscapping + float softscapping, + int64_t stream_ ) { // 
int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1); @@ -630,7 +631,7 @@ void paged_attention_v1_launcher( dim3 grid(num_heads, num_seqs, 1); dim3 block(NUM_THREADS); - const cudaStream_t stream = 0; + const cudaStream_t stream = (cudaStream_t)stream_; switch (head_size) { // NOTE(woosuk): To reduce the compilation time, we only compile for the // head sizes that we use in the model. However, we can easily extend this @@ -676,7 +677,8 @@ void paged_attention_v1_launcher( q_stride, \ kv_block_stride, \ kv_head_stride, \ - softscapping); + softscapping, \ + stream); // NOTE(woosuk): To reduce the compilation time, we omitted block sizes // 1, 2, 4, 64, 128, 256. @@ -716,7 +718,8 @@ extern "C" void paged_attention_v1( int32_t kv_head_stride, uint32_t dtype, // 0 => f16; 1 => bf16; 2 => f32 - float softscapping + float softscapping, + int64_t stream ) { if (dtype == 2) { CALL_V1_LAUNCHER_BLOCK_SIZE(float); @@ -781,7 +784,8 @@ void paged_attention_v2_launcher( int q_stride, int kv_block_stride, int kv_head_stride, - float softscapping + float softscapping, + int64_t stream_ ) { // int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1); @@ -803,7 +807,7 @@ void paged_attention_v2_launcher( int reduce_shared_mem_size = 2 * max_num_partitions * sizeof(float); dim3 block(NUM_THREADS); - const cudaStream_t stream = 0; + const cudaStream_t stream = (cudaStream_t)stream_; switch (head_size) { // NOTE(woosuk): To reduce the compilation time, we only compile for the // head sizes that we use in the model. However, we can easily extend this @@ -852,7 +856,8 @@ void paged_attention_v2_launcher( q_stride, \ kv_block_stride, \ kv_head_stride,\ - softscapping); + softscapping, \ + stream); // NOTE(woosuk): To reduce the compilation time, we omitted block sizes // 1, 2, 4, 64, 128, 256. @@ -895,7 +900,8 @@ extern "C" void paged_attention_v2( int32_t kv_head_stride, uint32_t dtype, // 0 => f16; 1 => bf16; 2 => f32 - float softscapping + float softscapping, + int64_t stream ) { if (dtype == 2) { CALL_V2_LAUNCHER_BLOCK_SIZE(float); diff --git a/kernels/src/reshape_and_cache_kernel.cu b/kernels/src/reshape_and_cache_kernel.cu index 7f0fd84..1e5635a 100644 --- a/kernels/src/reshape_and_cache_kernel.cu +++ b/kernels/src/reshape_and_cache_kernel.cu @@ -89,13 +89,13 @@ extern "C" void call_reshape_and_cache( int32_t x, int32_t key_stride, int32_t value_stride, - - uint32_t dtype // 0 => f16; 1 => bf16; 2 => f32 + uint32_t dtype, // 0 => f16; 1 => bf16; 2 => f32 + int64_t stream_ ) { dim3 grid(num_tokens); dim3 block(std::min(num_heads * head_size, 512)); - const cudaStream_t stream = 0; + const cudaStream_t stream = (cudaStream_t)stream_; if (dtype == 0){ CALL_RESHAPE_AND_CACHE(uint16_t); diff --git a/src/backend/gptq.rs b/src/backend/gptq.rs index 8602433..3985860 100644 --- a/src/backend/gptq.rs +++ b/src/backend/gptq.rs @@ -94,6 +94,7 @@ impl GPTQMatMul { size_n as i32, //n workspace_ptr, groupsize as i32, + *dev.cu_stream() as i64, ); } else if x.dtype() == DType::BF16 { marlin_4bit_bf16( @@ -106,6 +107,7 @@ impl GPTQMatMul { size_n as i32, //n workspace_ptr, groupsize as i32, + *dev.cu_stream() as i64, ); } } else { @@ -145,6 +147,7 @@ impl GPTQMatMul { size_n as i32, size_k as i32, self.bits, + *dev.cu_stream() as i64, ) } else { candle::bail!("GPTQMatMul is only supported for f16 non-marlin matmul. 
Use '--dtype f16' parameter instead."); @@ -244,7 +247,15 @@ impl GPTQRepack { let out_ptr = *out.device_ptr() as *const core::ffi::c_void; let q_ptr = *q.device_ptr() as *const core::ffi::c_void; - unsafe { gptq_repack(q_ptr, out_ptr, q_shape[0] as i32, q_shape[1] as i32) } + unsafe { + gptq_repack( + q_ptr, + out_ptr, + q_shape[0] as i32, + q_shape[1] as i32, + *dev.cu_stream() as i64, + ) + } let out = CudaStorage::wrap_cuda_slice(out, dev.clone()); Ok((out, oshape)) diff --git a/src/backend/paged_attention.rs b/src/backend/paged_attention.rs index 9cb7029..93d0000 100644 --- a/src/backend/paged_attention.rs +++ b/src/backend/paged_attention.rs @@ -191,6 +191,7 @@ impl PagedAttention { kv_head_stride as c_int, internal_type, self.softcapping, + *dev.cu_stream() as i64, ) } } else { @@ -228,6 +229,7 @@ impl PagedAttention { kv_head_stride as c_int, internal_type, self.softcapping, + *dev.cu_stream() as i64, ) } } @@ -555,6 +557,7 @@ impl ReshapeCache { ) -> Result<()> { use candle::cuda_backend::cudarc::driver::DevicePtr; let dtype = k.dtype(); + let dev = k.device(); let internal_type = match dtype { DType::F16 => 0, DType::BF16 => 1, @@ -663,7 +666,6 @@ impl ReshapeCache { let kc_ptr = *kc.device_ptr() as *const core::ffi::c_void; let vc_ptr = *vc.device_ptr() as *const core::ffi::c_void; let s_ptr = *s.device_ptr() as *const core::ffi::c_long; - unsafe { kernels::ffi::call_reshape_and_cache( k_ptr, @@ -679,6 +681,7 @@ impl ReshapeCache { key_stride, value_stride, internal_type, + *dev.cu_stream() as i64, ) } Ok(()) diff --git a/src/lib.rs b/src/lib.rs index e76346f..5e11781 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -3,7 +3,7 @@ use candle::utils::{cuda_is_available, metal_is_available}; use candle::{Device, Result}; use candle_core as candle; use clap::Subcommand; -use openai::pipelines::{pipeline::DefaultLoader, ModelLoader}; +use openai::pipelines::pipeline::DefaultLoader; use std::fmt::Display; use std::path::Path; @@ -11,7 +11,6 @@ pub mod backend; pub mod openai; pub mod paged_attention; pub mod scheduler; - #[derive(Debug, Subcommand)] pub enum ModelSelected { /// Select the llama model (default llama2-7b). @@ -266,7 +265,7 @@ impl SpecificConfig { pub fn get_model_loader( selected_model: ModelSelected, model_id: Option, -) -> (Box, String, Option) { +) -> (Box, String, Option) { match selected_model { ModelSelected::Llama { repeat_last_n, @@ -532,11 +531,11 @@ pub fn hub_load_local_safetensors( Ok(safetensors_files) } -pub fn new_device(cpu: bool, ordinal: usize) -> Result { - if cpu { - Ok(Device::Cpu) - } else if cuda_is_available() { - Ok(Device::new_cuda(ordinal)?) +pub fn new_device(ordinal: usize) -> Result { + if cuda_is_available() { + use candle_core::CudaDevice; + let device = Device::Cuda(CudaDevice::new_with_stream(ordinal).unwrap()); + Ok(device) } else if metal_is_available() { Ok(Device::new_metal(ordinal)?) 
} else { diff --git a/src/main.rs b/src/main.rs index 41020b1..fbaba8c 100644 --- a/src/main.rs +++ b/src/main.rs @@ -4,11 +4,11 @@ use axum::{ Router, }; use candle_core::{DType, Device}; +use candle_vllm::openai::openai_server::chat_completions; use candle_vllm::openai::pipelines::llm_engine::LLMEngine; use candle_vllm::openai::pipelines::pipeline::DefaultModelPaths; use candle_vllm::openai::responses::APIError; use candle_vllm::openai::OpenAIServerData; -use candle_vllm::openai::{openai_server::chat_completions, PipelineConfig}; use candle_vllm::scheduler::cache_engine::CacheConfig; use candle_vllm::scheduler::SchedulerConfig; use candle_vllm::{get_model_loader, hub_load_local_safetensors, ModelSelected}; @@ -90,19 +90,20 @@ fn get_cache_config( kvcache_mem_cpu: usize, block_size: usize, config: &Config, + num_shards: usize, ) -> CacheConfig { let dsize = config.kv_cache_dtype.size_in_bytes(); let num_gpu_blocks = kvcache_mem_gpu * SIZE_IN_MB / dsize / block_size - / config.num_key_value_heads + / (config.num_key_value_heads / num_shards) / config.get_head_size() / config.num_hidden_layers / 2; let num_cpu_blocks = kvcache_mem_cpu * SIZE_IN_MB / dsize / block_size - / config.num_key_value_heads + / (config.num_key_value_heads / num_shards) / config.get_head_size() / config.num_hidden_layers / 2; @@ -125,7 +126,7 @@ async fn main() -> Result<(), APIError> { let paths = match (&args.weight_path, &args.weight_file) { //model in a folder (safetensor format, huggingface folder structure) - (Some(path), None) => Box::new(DefaultModelPaths { + (Some(path), None) => DefaultModelPaths { tokenizer_filename: Path::new(path).join("tokenizer.json"), config_filename: Path::new(path).join("config.json"), filenames: if Path::new(path) @@ -139,9 +140,9 @@ async fn main() -> Result<(), APIError> { safetensors_files.insert(0, Path::new(path).join("model.safetensors")); safetensors_files }, - }), + }, //model in a quantized file (gguf/ggml format) - (Some(path), Some(file)) => Box::new(DefaultModelPaths { + (Some(path), Some(file)) => DefaultModelPaths { tokenizer_filename: { //we need to download tokenizer for the ggufl/ggml model let api = hf_hub::api::sync::Api::new().unwrap(); @@ -154,7 +155,7 @@ async fn main() -> Result<(), APIError> { } else { panic!("Model file not found {}", file); }, - }), + }, _ => { if args.hf_token.is_none() && args.hf_token_path.is_none() { //no token provided @@ -193,64 +194,43 @@ async fn main() -> Result<(), APIError> { Some(ids) => ids, _ => vec![0usize], }; - use candle_vllm::openai::pipelines::ModulePipeline; + let num_shards = device_ids.len(); use candle_vllm::scheduler::cache_engine::CacheEngine; - use std::collections::HashMap; - let mut pipelines = HashMap::, CacheEngine)>::new(); - use std::rc::Rc; - let mut cache_config: Option = None; + let (default_pipelines, pipeline_config) = + loader.load_model(paths, dtype, &quant, device_ids).await?; + let mut config: Option = None; - let mut pipeline_config: Option = None; + let mut cache_config: Option = None; - let num_shards = device_ids.len(); - #[cfg(feature = "nccl")] - use cudarc::nccl::safe::{Comm, Id}; - #[cfg(feature = "nccl")] - let id = Id::new().unwrap(); - for (rank, did) in device_ids.iter().enumerate() { - let device = candle_vllm::new_device(args.cpu, *did).unwrap(); - // let device = device.as_cuda_device().unwrap(); - #[cfg(feature = "nccl")] - let comm = match Comm::from_rank( - device.as_cuda_device().unwrap().cuda_device(), - rank, - num_shards, - id, - ) { - Ok(comm) => Rc::new(comm), - Err(err) => 
panic!("nccl error {:?}", err.0), - }; - println!("Loading model on device rank {}", rank); - let model = loader.load_model( - &paths, - dtype, - &quant, - device.clone(), - #[cfg(feature = "nccl")] - Some(comm), - )?; - if config.is_none() { - config = Some(model.0.get_model_config()); - } - if cache_config.is_none() { - cache_config = Some(get_cache_config( + let pipelines = default_pipelines + .into_iter() + .map(|pipeline| { + let cfg = pipeline.get_model_config(); + let cache_cfg = get_cache_config( args.kvcache_mem_gpu, args.kvcache_mem_cpu, args.block_size, - &config.as_ref().expect("invalid config!"), - )); - } - if pipeline_config.is_none() { - pipeline_config = Some(model.1); - } - let cache_engine = CacheEngine::new( - config.as_ref().expect("invalid config!"), - cache_config.as_ref().expect("invalid cache config!"), - cache_config.as_ref().expect("invalid cache config!").dtype, - &device, - )?; - pipelines.insert(rank, (model.0, cache_engine)); - } + &cfg, + num_shards, + ); + let cache_engine = CacheEngine::new( + &cfg, + &cache_cfg, + cache_cfg.dtype, + &pipeline.device(), + num_shards, + ) + .unwrap(); + if config.is_none() { + config = Some(cfg.clone()); + } + if cache_config.is_none() { + cache_config = Some(cache_cfg.clone()); + } + (pipeline.rank(), (pipeline, cache_engine)) + }) + .collect(); + let cache_config = cache_config.as_ref().unwrap().clone(); let config = config.as_ref().unwrap().clone(); println!("Cache config {:?}", cache_config); @@ -267,7 +247,7 @@ async fn main() -> Result<(), APIError> { )?; let server_data = OpenAIServerData { - pipeline_config: pipeline_config.unwrap(), + pipeline_config, model: llm_engine, record_conversation: args.record_conversation, device: Device::Cpu, diff --git a/src/openai/conversation/default_conversation.rs b/src/openai/conversation/default_conversation.rs index 232735b..6990d02 100644 --- a/src/openai/conversation/default_conversation.rs +++ b/src/openai/conversation/default_conversation.rs @@ -7,7 +7,7 @@ pub const SYSTEM_TEMPLATE: &str = "{}"; pub const DEFAULT_SEP: &str = "\n"; /// Separator style for default conversation. -#[derive(Default)] +#[derive(Default, Clone)] pub enum SeparatorStyle { #[default] AddColonSingle, diff --git a/src/openai/models/gemma.rs b/src/openai/models/gemma.rs index ae680cd..736b857 100644 --- a/src/openai/models/gemma.rs +++ b/src/openai/models/gemma.rs @@ -280,7 +280,7 @@ impl Attention { } fn forward( - &mut self, + &self, xs: &Tensor, attention_mask: Option<&Tensor>, input_positions: &[Vec], @@ -406,7 +406,7 @@ impl DecoderLayer { } fn forward( - &mut self, + &self, xs: &Tensor, attention_mask: Option<&Tensor>, input_positions: &[Vec], @@ -499,7 +499,7 @@ impl Gemma { } pub fn forward( - &mut self, + &self, input_ids: &Tensor, input_positions: &[Vec], kv_caches: Option<&Vec<(Tensor, Tensor)>>, @@ -515,7 +515,7 @@ impl Gemma { let xs = self.embed_tokens.forward(input_ids)?; let mut xs = (xs * (self.hidden_size as f64).sqrt())?; if let Some(kv_caches) = kv_caches { - for ((k_cache, v_cache), layer) in zip(kv_caches.iter(), self.layers.iter_mut()) { + for ((k_cache, v_cache), layer) in zip(kv_caches.iter(), self.layers.iter()) { xs = layer.forward( &xs, attention_mask.as_ref(), @@ -526,7 +526,7 @@ impl Gemma { )? 
} } else { - for layer in self.layers.iter_mut() { + for layer in self.layers.iter() { xs = layer.forward( &xs, attention_mask.as_ref(), diff --git a/src/openai/models/llama.rs b/src/openai/models/llama.rs index 71156f0..050d3b2 100644 --- a/src/openai/models/llama.rs +++ b/src/openai/models/llama.rs @@ -9,8 +9,6 @@ use candle_nn::{embedding, Embedding, Module, VarBuilder}; use candle_transformers::models::with_tracing::RmsNorm; pub const MAX_SEQ_LEN: usize = 4096; use crate::openai::models::TokenID; -#[cfg(feature = "nccl")] -pub use cudarc::nccl::safe::Comm; use std::iter::zip; pub use std::rc::Rc; @@ -136,7 +134,7 @@ impl CausalSelfAttention { } fn forward( - &mut self, + &self, x: &Tensor, attention_mask: Option<&Tensor>, input_positions: &[Vec], @@ -319,7 +317,7 @@ struct Block { impl Block { fn forward( - &mut self, + &self, x: &Tensor, attention_mask: Option<&Tensor>, input_positions: &[Vec], @@ -380,7 +378,7 @@ impl Llama { } pub fn forward( - &mut self, + &self, x: &Tensor, input_positions: &[Vec], kv_caches: Option<&Vec<(Tensor, Tensor)>>, @@ -395,7 +393,7 @@ impl Llama { }; let mut x = self.wte.forward(x)?; if let Some(kv_caches) = kv_caches { - for ((k_cache, v_cache), block) in zip(kv_caches.iter(), &mut self.blocks) { + for ((k_cache, v_cache), block) in zip(kv_caches.iter(), &self.blocks) { x = block.forward( &x, attention_mask.as_ref(), @@ -405,7 +403,7 @@ impl Llama { )?; } } else { - for block in &mut self.blocks { + for block in &self.blocks { x = block.forward( &x, attention_mask.as_ref(), @@ -421,13 +419,7 @@ impl Llama { logits.to_dtype(DType::F32) } - pub fn load( - vb: VarBuilder, - cfg: &Config, - dtype: DType, - device: &Device, - comm: Option>, - ) -> Result { + pub fn load(vb: VarBuilder, cfg: &Config, dtype: DType, device: &Device) -> Result { let wte = embedding(cfg.vocab_size, cfg.hidden_size, vb.pp("model.embed_tokens"))?; let lm_head = linear( cfg.hidden_size, diff --git a/src/openai/models/llama_multi.rs b/src/openai/models/llama_multi.rs new file mode 100644 index 0000000..6d738fe --- /dev/null +++ b/src/openai/models/llama_multi.rs @@ -0,0 +1,422 @@ +use super::{Config, QuantConfig}; +use crate::openai::distributed::{shard, TensorParallelColumnLinear, TensorParallelRowLinear}; +use crate::paged_attention::input_metadata::InputMetadata; +use crate::paged_attention::PagedAttention; +use crate::SpecificConfig; +use candle::{DType, Device, IndexOp, Result, Tensor}; +use candle_core as candle; +use candle_nn::var_builder::ShardedVarBuilder as VarBuilder; +use candle_nn::{Embedding, Linear, Module, RmsNorm}; +pub use cudarc::nccl::safe::{Comm, ReduceOp}; +pub const MAX_SEQ_LEN: usize = 4096; +use crate::openai::models::TokenID; +use std::iter::zip; +use std::rc::Rc; +#[derive(Debug, Clone, serde::Deserialize)] +pub struct LlamaConfig { + pub hidden_size: usize, + pub intermediate_size: usize, + pub vocab_size: usize, + pub num_hidden_layers: usize, + pub num_attention_heads: usize, + pub num_key_value_heads: Option, + pub rms_norm_eps: f64, + #[serde(default = "default_rope")] + pub rope_theta: f32, + pub bos_token_id: TokenID, + pub eos_token_id: TokenID, + pub max_position_embeddings: Option, + pub quantization_config: Option, +} + +fn default_rope() -> f32 { + 10_000.0 +} + +impl LlamaConfig { + pub fn into_config( + self, + use_flash_attn: bool, + kv_cache_dtype: DType, + scfg: &SpecificConfig, + ) -> Config { + Config { + hidden_size: self.hidden_size, + head_dim: Some(self.hidden_size / self.num_attention_heads), + intermediate_size: 
self.intermediate_size, + vocab_size: self.vocab_size, + num_hidden_layers: self.num_hidden_layers, + num_attention_heads: self.num_attention_heads, + num_key_value_heads: self.num_key_value_heads.unwrap_or(self.num_attention_heads), + rms_norm_eps: self.rms_norm_eps, + rope_theta: f64::from(self.rope_theta), + use_flash_attn, + bos_token_id: self.bos_token_id, + eos_token_id: self.eos_token_id, + max_seq_len: self.max_position_embeddings.unwrap_or(MAX_SEQ_LEN), + sliding_window: None, + hidden_act: None, + tie_word_embeddings: false, + rope_scaling: None, + original_max_position_embeddings: None, + attention_bias: false, + partial_rotary_factor: None, + qk_layer_rms_norm: None, + kv_cache_dtype, + use_qkv_bias: None, + custom_stop_tokens: None, + specific_config: scfg.clone(), + attn_logit_softcapping: None, + final_logit_softcapping: None, + quantization_config: self.quantization_config, + } + } +} + +#[derive(Debug, Clone)] +pub struct Cache { + cos: Tensor, + sin: Tensor, +} + +impl Cache { + pub fn new(dtype: DType, config: &Config, device: &Device) -> Result { + // precompute freqs_cis + let n_elem = config.hidden_size / config.num_attention_heads; + let theta: Vec<_> = (0..n_elem) + .step_by(2) + .map(|i| 1f32 / config.rope_theta.powf(i as f64 / n_elem as f64) as f32) + .collect(); + let theta = Tensor::new(theta.as_slice(), device)?; + let idx_theta = Tensor::arange(0, config.max_seq_len as u32, device)? + .to_dtype(DType::F32)? + .reshape((config.max_seq_len, 1))? + .matmul(&theta.reshape((1, theta.elem_count()))?)?; + let cos = idx_theta.cos()?.to_dtype(dtype)?; + let sin = idx_theta.sin()?.to_dtype(dtype)?; + Ok(Self { cos, sin }) + } +} + +fn linear(size1: usize, size2: usize, vb: VarBuilder) -> Result { + let weight = vb.get((size2, size1), "weight")?; + Ok(Linear::new(weight, None)) +} + +fn embedding(cfg: &Config, vb: VarBuilder) -> Result { + let embeddings = vb.get((cfg.vocab_size, cfg.hidden_size), "weight")?; + Ok(Embedding::new(embeddings, cfg.hidden_size)) +} + +struct CausalSelfAttention { + qkv_proj: TensorParallelColumnLinear, + o_proj: TensorParallelRowLinear, + num_attention_heads: usize, + num_key_value_heads: usize, + head_dim: usize, + attn: PagedAttention, + cos_sin_cache: Cache, +} + +impl CausalSelfAttention { + fn apply_rotary_emb(&self, x: &Tensor, input_positions: &[Vec]) -> Result { + let (b_sz, _, seq_len, _hidden_size) = x.dims4()?; + let mut embeds = Vec::new(); + for (b, seqlen_offset) in zip(0..b_sz, input_positions) { + let cos = self + .cos_sin_cache + .cos + .narrow(0, seqlen_offset[0], seq_len)?; + let sin = self + .cos_sin_cache + .sin + .narrow(0, seqlen_offset[0], seq_len)?; + let x_b = x.narrow(0, b, 1)?; + let embed = candle_nn::rotary_emb::rope(&x_b, &cos, &sin).unwrap(); + embeds.push(embed); + } + Tensor::cat(&embeds, 0) + } + + fn forward( + &self, + x: &Tensor, + attention_mask: Option<&Tensor>, + input_positions: &[Vec], + cache: Option<(&Tensor, &Tensor)>, + input_metadata: &InputMetadata, + ) -> Result { + let (b_sz, seq_len, _) = x.dims3()?; + let qkv = self.qkv_proj.forward(x)?; + let hidden_size = self.num_attention_heads * self.head_dim; + + let q = qkv.i((.., .., ..self.num_attention_heads * self.head_dim))?; + let k = qkv.i(( + .., + .., + self.num_attention_heads * self.head_dim + ..self.num_attention_heads * self.head_dim + + self.num_key_value_heads * self.head_dim, + ))?; + let v = qkv.i(( + .., + .., + self.num_attention_heads * self.head_dim + self.num_key_value_heads * self.head_dim.., + ))?; + + let (q, k, v) = if 
seq_len == 1 { + //no need transpose for seq_len == 1, change reshape dim + let q = q.reshape((b_sz, self.num_attention_heads, seq_len, self.head_dim))?; + let k = k.reshape((b_sz, self.num_key_value_heads, seq_len, self.head_dim))?; + let v = v.reshape((b_sz, self.num_key_value_heads, seq_len, self.head_dim))?; + (q, k, v) + } else { + let q = q + .reshape((b_sz, seq_len, self.num_attention_heads, self.head_dim))? + .transpose(1, 2)? + .contiguous()?; + let k = k + .reshape((b_sz, seq_len, self.num_key_value_heads, self.head_dim))? + .transpose(1, 2)? + .contiguous()?; + let v = v + .reshape((b_sz, seq_len, self.num_key_value_heads, self.head_dim))? + .transpose(1, 2)?; + (q, k, v.contiguous()?) + }; + + let q = self.apply_rotary_emb(&q, input_positions)?; + let k = self.apply_rotary_emb(&k, input_positions)?; + + let y = self.attn.forward( + &q, + &k, + &v, + attention_mask, + cache.map(|(k_, _)| k_.clone()), + cache.map(|(_, v_)| v_.clone()), + input_metadata, + None, + )?; + + let y = if attention_mask.is_some() { + y.transpose(1, 2)?.reshape(&[b_sz, seq_len, hidden_size])? + } else { + y.reshape(&[b_sz, seq_len, hidden_size])? + }; + let y = self.o_proj.forward(&y)?; + Ok(y) + } + + fn load( + vb: VarBuilder, + cfg: &Config, + dtype: DType, + device: &Device, + comm: Rc, + ) -> Result { + let qkv_proj = TensorParallelColumnLinear::load_multi( + vb.clone(), + &["q_proj", "k_proj", "v_proj"], + comm.clone(), + )?; + let o_proj = TensorParallelRowLinear::load(vb.pp("o_proj"), comm.clone())?; + let head_dim = cfg.hidden_size / cfg.num_attention_heads; + let attention_heads = cfg.num_attention_heads / comm.world_size(); + let kv_heads = cfg.num_key_value_heads / comm.world_size(); + Ok(Self { + qkv_proj, + o_proj, + num_attention_heads: attention_heads, + num_key_value_heads: kv_heads, + head_dim, + attn: PagedAttention::new( + attention_heads, + head_dim, + 1. / ((head_dim as f32).sqrt()), + Some(kv_heads), + None, + vb.device().clone(), + None, + )?, + cos_sin_cache: Cache::new(dtype, cfg, device)?, + }) + } +} + +struct Mlp { + c_fc1: TensorParallelColumnLinear, + c_fc2: TensorParallelColumnLinear, + c_proj: TensorParallelRowLinear, +} + +impl Mlp { + fn forward(&self, x: &Tensor) -> Result { + let x = (candle_nn::ops::silu(&self.c_fc1.forward(x)?)? * self.c_fc2.forward(x)?)?; + self.c_proj.forward(&x) + } + + fn load(vb: VarBuilder, comm: Rc) -> Result { + let c_fc1 = TensorParallelColumnLinear::load(vb.pp("gate_proj"), comm.clone())?; + let c_fc2 = TensorParallelColumnLinear::load(vb.pp("up_proj"), comm.clone())?; + let c_proj = TensorParallelRowLinear::load(vb.pp("down_proj"), comm)?; + Ok(Self { + c_fc1, + c_fc2, + c_proj, + }) + } +} +fn rms_norm(size: usize, eps: f64, vb: VarBuilder) -> Result { + let weight = vb.get_with_hints(size, "weight", shard(0, 0, 1))?; + Ok(RmsNorm::new(weight, eps)) +} + +struct Block { + rms_1: RmsNorm, + attn: CausalSelfAttention, + rms_2: RmsNorm, + mlp: Mlp, +} + +impl Block { + fn forward( + &self, + x: &Tensor, + attention_mask: Option<&Tensor>, + input_positions: &[Vec], + cache: Option<(&Tensor, &Tensor)>, + input_metadata: &InputMetadata, + ) -> Result { + let residual = x; + let x = self.rms_1.forward(x)?; + let x = (self + .attn + .forward(&x, attention_mask, input_positions, cache, input_metadata)? + + residual)?; + let residual = &x; + let x = (self.mlp.forward(&self.rms_2.forward(&x)?)? 
+ residual)?; + Ok(x) + } + + fn load( + vb: VarBuilder, + cfg: &Config, + dtype: DType, + device: &Device, + comm: Rc, + ) -> Result { + let attn = CausalSelfAttention::load(vb.pp("self_attn"), cfg, dtype, device, comm.clone())?; + let mlp = Mlp::load(vb.pp("mlp"), comm.clone())?; + let rms_1 = rms_norm(cfg.hidden_size, 1e-5, vb.pp("input_layernorm"))?; + let rms_2 = rms_norm(cfg.hidden_size, 1e-5, vb.pp("post_attention_layernorm"))?; + Ok(Self { + rms_1, + attn, + rms_2, + mlp, + }) + } +} + +pub struct LlamaMulti { + wte: Embedding, + blocks: Vec, + ln_f: RmsNorm, + lm_head: Linear, + cfg: Config, + dtype: DType, + device: Device, +} + +impl LlamaMulti { + fn prepare_decoder_attention_mask(&self, b_size: usize, tgt_len: usize) -> Result { + let mask: Vec<_> = (0..tgt_len) + .flat_map(|i| (0..tgt_len).map(move |j| if i < j { f32::NEG_INFINITY } else { 0. })) + .collect(); + let mask = Tensor::from_slice(&mask, (tgt_len, tgt_len), &self.device)?; + mask.expand((b_size, 1, tgt_len, tgt_len))? + .contiguous()? + .to_dtype(self.dtype) + } + + pub fn forward( + &self, + x: &Tensor, + input_positions: &[Vec], + kv_caches: Option<&Vec<(Tensor, Tensor)>>, + input_metadata: &InputMetadata, + ) -> Result { + let (_b_sz, seq_len) = x.dims2()?; + let attention_mask = if seq_len <= 1 { + None + } else { + let mask = self.prepare_decoder_attention_mask(_b_sz, seq_len)?; + Some(mask) + }; + let mut x = self.wte.forward(x)?; + if let Some(kv_caches) = kv_caches { + for ((k_cache, v_cache), block) in zip(kv_caches.iter(), &self.blocks) { + x = block.forward( + &x, + attention_mask.as_ref(), + input_positions, + Some((k_cache, v_cache)), + input_metadata, + )?; + } + } else { + for block in &self.blocks { + x = block.forward( + &x, + attention_mask.as_ref(), + input_positions, + None, + input_metadata, + )?; + } + } + let x = self.ln_f.forward(&x)?; + let x = x.i((.., seq_len - 1, ..))?.contiguous()?; + let logits = self.lm_head.forward(&x)?; + logits.to_dtype(DType::F32) + } + + pub fn load( + vb: VarBuilder, + cfg: &Config, + dtype: DType, + device: &Device, + comm: Rc, + ) -> Result { + let wte = embedding(cfg, vb.pp("model.embed_tokens"))?; + let lm_head = linear(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?; + let ln_f = rms_norm(cfg.hidden_size, 1e-5, vb.pp("model.norm"))?; + let blocks: Vec<_> = (0..cfg.num_hidden_layers) + .map(|i| { + Block::load( + vb.pp(&format!("model.layers.{i}")), + cfg, + dtype, + device, + comm.clone(), + ) + .unwrap() + }) + .collect(); + + Ok(Self { + wte, + blocks, + ln_f, + lm_head, + cfg: cfg.clone(), + dtype, + device: device.clone(), + }) + } + + pub fn get_config(&self) -> &Config { + &self.cfg + } +} diff --git a/src/openai/models/mistral.rs b/src/openai/models/mistral.rs index f0af311..d47dc2c 100644 --- a/src/openai/models/mistral.rs +++ b/src/openai/models/mistral.rs @@ -256,7 +256,7 @@ impl Attention { } fn forward( - &mut self, + &self, xs: &Tensor, attention_mask: Option<&Tensor>, input_positions: &[Vec], @@ -351,7 +351,7 @@ impl DecoderLayer { } fn forward( - &mut self, + &self, xs: &Tensor, attention_mask: Option<&Tensor>, input_positions: &[Vec], @@ -433,7 +433,7 @@ impl Mistral { } pub fn forward( - &mut self, + &self, input_ids: &Tensor, input_positions: &[Vec], kv_caches: Option<&Vec<(Tensor, Tensor)>>, @@ -448,7 +448,7 @@ impl Mistral { }; let mut xs = self.embed_tokens.forward(input_ids)?; if let Some(kv_caches) = kv_caches { - for ((k_cache, v_cache), layer) in zip(kv_caches.iter(), self.layers.iter_mut()) { + for ((k_cache, v_cache), layer) 
in zip(kv_caches.iter(), self.layers.iter()) { xs = layer.forward( &xs, attention_mask.as_ref(), @@ -458,7 +458,7 @@ impl Mistral { )? } } else { - for layer in self.layers.iter_mut() { + for layer in self.layers.iter() { xs = layer.forward( &xs, attention_mask.as_ref(), diff --git a/src/openai/models/mod.rs b/src/openai/models/mod.rs index 16ee7c9..03e703a 100644 --- a/src/openai/models/mod.rs +++ b/src/openai/models/mod.rs @@ -1,6 +1,8 @@ pub mod gemma; pub mod linear; pub mod llama; +#[cfg(feature = "nccl")] +pub mod llama_multi; pub mod mistral; pub mod phi2; pub mod phi3; diff --git a/src/openai/models/phi2.rs b/src/openai/models/phi2.rs index 4d6679b..ec7d9c3 100644 --- a/src/openai/models/phi2.rs +++ b/src/openai/models/phi2.rs @@ -245,7 +245,7 @@ impl Attention { } fn forward( - &mut self, + &self, xs: &Tensor, attention_mask: Option<&Tensor>, input_positions: &[Vec], @@ -336,7 +336,7 @@ impl DecoderLayer { } fn forward( - &mut self, + &self, xs: &Tensor, mask: Option<&Tensor>, input_positions: &[Vec], @@ -407,7 +407,7 @@ impl Phi2 { } pub fn forward( - &mut self, + &self, xs: &Tensor, input_positions: &[Vec], kv_caches: Option<&Vec<(Tensor, Tensor)>>, @@ -422,7 +422,7 @@ impl Phi2 { Some(mask) }; if let Some(kv_caches) = kv_caches { - for ((k_cache, v_cache), layer) in zip(kv_caches.iter(), self.layers.iter_mut()) { + for ((k_cache, v_cache), layer) in zip(kv_caches.iter(), self.layers.iter()) { xs = layer.forward( &xs, attention_mask.as_ref(), @@ -432,7 +432,7 @@ impl Phi2 { )? } } else { - for layer in self.layers.iter_mut() { + for layer in self.layers.iter() { xs = layer.forward( &xs, attention_mask.as_ref(), diff --git a/src/openai/models/phi3.rs b/src/openai/models/phi3.rs index 08f5dbe..6f70f77 100644 --- a/src/openai/models/phi3.rs +++ b/src/openai/models/phi3.rs @@ -289,7 +289,7 @@ impl Attention { } fn forward( - &mut self, + &self, xs: &Tensor, attention_mask: Option<&Tensor>, input_positions: &[Vec], @@ -439,7 +439,7 @@ impl DecoderLayer { } fn forward( - &mut self, + &self, xs: &Tensor, attention_mask: Option<&Tensor>, input_positions: &[Vec], @@ -510,7 +510,7 @@ impl Phi { } pub fn forward( - &mut self, + &self, input_ids: &Tensor, input_positions: &[Vec], kv_caches: Option<&Vec<(Tensor, Tensor)>>, @@ -526,7 +526,7 @@ impl Phi { let mut xs = self.embed_tokens.forward(input_ids)?; if let Some(kv_caches) = kv_caches { - for ((k_cache, v_cache), layer) in zip(kv_caches.iter(), self.layers.iter_mut()) { + for ((k_cache, v_cache), layer) in zip(kv_caches.iter(), self.layers.iter()) { xs = layer.forward( &xs, attention_mask.as_ref(), @@ -536,7 +536,7 @@ impl Phi { )? 
} } else { - for layer in self.layers.iter_mut() { + for layer in self.layers.iter() { xs = layer.forward( &xs, attention_mask.as_ref(), diff --git a/src/openai/models/quantized_llama.rs b/src/openai/models/quantized_llama.rs index 066f6a8..f8a748f 100644 --- a/src/openai/models/quantized_llama.rs +++ b/src/openai/models/quantized_llama.rs @@ -168,7 +168,7 @@ impl LayerWeights { } fn forward_attn( - &mut self, + &self, x: &Tensor, mask: Option<&Tensor>, input_positions: &[Vec], @@ -520,7 +520,7 @@ impl GGUFLLaMa { } pub fn forward( - &mut self, + &self, x: &Tensor, input_positions: &[Vec], kv_caches: Option<&Vec<(Tensor, Tensor)>>, @@ -535,7 +535,7 @@ impl GGUFLLaMa { let mut layer_in = self.tok_embeddings.forward(x)?; if let Some(kv_caches) = kv_caches { - for ((k_cache, v_cache), layer) in zip(kv_caches.iter(), self.layers.iter_mut()) { + for ((k_cache, v_cache), layer) in zip(kv_caches.iter(), self.layers.iter()) { let x = layer_in; let residual = &x; let x = layer.attention_norm.forward(&x)?; @@ -556,7 +556,7 @@ impl GGUFLLaMa { layer_in = x } } else { - for layer in self.layers.iter_mut() { + for layer in self.layers.iter() { let x = layer_in; let residual = &x; let x = layer.attention_norm.forward(&x)?; diff --git a/src/openai/models/quantized_phi3.rs b/src/openai/models/quantized_phi3.rs index 5e13b97..2cda380 100644 --- a/src/openai/models/quantized_phi3.rs +++ b/src/openai/models/quantized_phi3.rs @@ -97,7 +97,7 @@ impl LayerWeights { } fn forward_attn( - &mut self, + &self, x: &Tensor, mask: Option<&Tensor>, input_positions: &[Vec], @@ -342,7 +342,7 @@ impl GGUFPhi3 { } pub fn forward( - &mut self, + &self, xs: &Tensor, input_positions: &[Vec], kv_caches: Option<&Vec<(Tensor, Tensor)>>, @@ -357,7 +357,7 @@ impl GGUFPhi3 { let mut xs = self.tok_embeddings.forward(xs)?; if let Some(kv_caches) = kv_caches { - for ((k_cache, v_cache), layer) in zip(kv_caches.iter(), self.layers.iter_mut()) { + for ((k_cache, v_cache), layer) in zip(kv_caches.iter(), self.layers.iter()) { let residual = &xs; let ys = xs.apply(&layer.attn_norm)?; let ys = layer.forward_attn( @@ -374,7 +374,7 @@ impl GGUFPhi3 { xs = (ys + residual)? } } else { - for layer in self.layers.iter_mut() { + for layer in self.layers.iter() { let residual = &xs; let ys = xs.apply(&layer.attn_norm)?; let ys = layer.forward_attn( diff --git a/src/openai/models/qwen2.rs b/src/openai/models/qwen2.rs index 43b70fa..3dbb3c9 100644 --- a/src/openai/models/qwen2.rs +++ b/src/openai/models/qwen2.rs @@ -259,7 +259,7 @@ impl Attention { } fn forward( - &mut self, + &self, xs: &Tensor, attention_mask: Option<&Tensor>, input_positions: &[Vec], @@ -353,7 +353,7 @@ impl DecoderLayer { } fn forward( - &mut self, + &self, xs: &Tensor, attention_mask: Option<&Tensor>, input_positions: &[Vec], @@ -447,7 +447,7 @@ impl Qwen2 { } pub fn forward( - &mut self, + &self, input_ids: &Tensor, input_positions: &[Vec], kv_caches: Option<&Vec<(Tensor, Tensor)>>, @@ -463,7 +463,7 @@ impl Qwen2 { let mut xs = self.embed_tokens.forward(input_ids)?; if let Some(kv_caches) = kv_caches { - for ((k_cache, v_cache), layer) in zip(kv_caches.iter(), self.layers.iter_mut()) { + for ((k_cache, v_cache), layer) in zip(kv_caches.iter(), self.layers.iter()) { xs = layer.forward( &xs, attention_mask.as_ref(), @@ -473,7 +473,7 @@ impl Qwen2 { )? 
} } else { - for layer in self.layers.iter_mut() { + for layer in self.layers.iter() { xs = layer.forward( &xs, attention_mask.as_ref(), diff --git a/src/openai/models/stable_lm.rs b/src/openai/models/stable_lm.rs index f07688d..b5f1f44 100644 --- a/src/openai/models/stable_lm.rs +++ b/src/openai/models/stable_lm.rs @@ -268,7 +268,7 @@ impl Attention { } fn forward( - &mut self, + &self, xs: &Tensor, attention_mask: Option<&Tensor>, input_positions: &[Vec], @@ -365,7 +365,7 @@ impl DecoderLayer { } fn forward( - &mut self, + &self, xs: &Tensor, attention_mask: Option<&Tensor>, input_positions: &[Vec], @@ -437,7 +437,7 @@ impl StableLM { } pub fn forward( - &mut self, + &self, input_ids: &Tensor, input_positions: &[Vec], kv_caches: Option<&Vec<(Tensor, Tensor)>>, @@ -452,7 +452,7 @@ impl StableLM { }; let mut xs = self.embed_tokens.forward(input_ids)?; if let Some(kv_caches) = kv_caches { - for ((k_cache, v_cache), layer) in zip(kv_caches.iter(), self.layers.iter_mut()) { + for ((k_cache, v_cache), layer) in zip(kv_caches.iter(), self.layers.iter()) { xs = layer.forward( &xs, attention_mask.as_ref(), @@ -462,7 +462,7 @@ impl StableLM { )? } } else { - for layer in self.layers.iter_mut() { + for layer in self.layers.iter() { xs = layer.forward( &xs, attention_mask.as_ref(), diff --git a/src/openai/models/yi.rs b/src/openai/models/yi.rs index 92173bb..4377fe8 100644 --- a/src/openai/models/yi.rs +++ b/src/openai/models/yi.rs @@ -255,7 +255,7 @@ impl Attention { } fn forward( - &mut self, + &self, xs: &Tensor, attention_mask: Option<&Tensor>, input_positions: &[Vec], @@ -349,7 +349,7 @@ impl DecoderLayer { } fn forward( - &mut self, + &self, xs: &Tensor, attention_mask: Option<&Tensor>, input_positions: &[Vec], @@ -421,7 +421,7 @@ impl Yi { } pub fn forward( - &mut self, + &self, input_ids: &Tensor, input_positions: &[Vec], kv_caches: Option<&Vec<(Tensor, Tensor)>>, @@ -436,7 +436,7 @@ impl Yi { }; let mut xs = self.embed_tokens.forward(input_ids)?; if let Some(kv_caches) = kv_caches { - for ((k_cache, v_cache), layer) in zip(kv_caches.iter(), self.layers.iter_mut()) { + for ((k_cache, v_cache), layer) in zip(kv_caches.iter(), self.layers.iter()) { xs = layer.forward( &xs, attention_mask.as_ref(), @@ -446,7 +446,7 @@ impl Yi { )? 
} } else { - for layer in self.layers.iter_mut() { + for layer in self.layers.iter() { xs = layer.forward( &xs, attention_mask.as_ref(), diff --git a/src/openai/pipelines/llm_engine.rs b/src/openai/pipelines/llm_engine.rs index f80c346..1c10f84 100644 --- a/src/openai/pipelines/llm_engine.rs +++ b/src/openai/pipelines/llm_engine.rs @@ -1,10 +1,4 @@ -use std::{ - collections::{HashMap, VecDeque}, - iter::zip, - sync::Arc, -}; - -use super::{ModulePipeline, _make_tensor_with_pad}; +use super::{DefaultPipeline, _make_tensor_with_pad}; use crate::openai::streaming::ChatResponse; use crate::scheduler::Scheduler; use crate::{ @@ -25,10 +19,15 @@ use crate::{ }, try_api, }; -use candle_core::Tensor; +use candle_core::{Device, Tensor}; use either::Either; use flume::Sender; use std::time::SystemTime; +use std::{ + collections::{HashMap, VecDeque}, + iter::zip, + sync::Arc, +}; use tokenizers::Encoding; use tokio::sync::Mutex; use tokio::sync::Notify; @@ -42,7 +41,7 @@ struct PreparedInputs { const _PAD_SLOT_ID: i64 = -1; pub struct LLMEngine { - pipelines: HashMap, CacheEngine)>, + pipelines: HashMap, CacheEngine)>, scheduler: Scheduler, seq_id: usize, cache_config: CacheConfig, @@ -55,7 +54,7 @@ pub struct LLMEngine { impl LLMEngine { pub fn new( - pipelines: HashMap, CacheEngine)>, + pipelines: HashMap, CacheEngine)>, scheduler_config: SchedulerConfig, cache_config: &CacheConfig, config: &Config, @@ -137,14 +136,14 @@ impl LLMEngine { Ok(engine_clone) } - pub fn get_pipeline(&self, rank: usize) -> Option<&(Box, CacheEngine)> { + pub fn get_pipeline(&self, rank: usize) -> Option<&(Box, CacheEngine)> { self.pipelines.get(&rank) } pub fn get_mut_pipeline( &mut self, rank: usize, - ) -> Option<&mut (Box, CacheEngine)> { + ) -> Option<&mut (Box, CacheEngine)> { self.pipelines.get_mut(&rank) } @@ -183,7 +182,6 @@ impl LLMEngine { let mut responses = HashMap::, ChatCompletionUsageResponse)>::new(); let mut prompt_finish_times = HashMap::::new(); - // let mut prompt_finish_time = SystemTime::now(); while self.scheduler.has_unfinished_sequences() { let scheduler_outputs = self.scheduler.schedule(); if !scheduler_outputs.ignored_seq_groups.is_empty() { @@ -193,76 +191,74 @@ impl LLMEngine { self.execute_scheduler_ops(&scheduler_outputs, 0).unwrap(); let scheduled: &VecDeque> = &scheduler_outputs.scheduled; - // for group in scheduled.iter() { let seqs = scheduled[0].get_seqs(); - let PreparedInputs { - tokens, - positions, - metadata, - } = if seqs.values().nth(0).unwrap().deref().is_prompt() { - self.prepare_prompt(scheduled, 0) - } else { - self.prepare_decode(scheduled, 0) - } - .unwrap(); - use rayon::iter::IntoParallelRefMutIterator; + #[cfg(feature = "nccl")] + use rayon::iter::IntoParallelRefIterator; + #[cfg(feature = "nccl")] use rayon::iter::ParallelIterator; - let vec_logits: Vec = self + #[cfg(feature = "nccl")] + let vec_logits: HashMap = self .pipelines - .par_iter_mut() + .par_iter() .map(|(rank, (pipeline, cache_engine))| { let device = pipeline.device(); - let metadata_ = if *rank == 0 { - &metadata + let PreparedInputs { + tokens, + positions, + metadata, + } = if seqs.values().nth(0).unwrap().deref().is_prompt() { + self.prepare_prompt(scheduled, device) } else { - let context_lens = if metadata.context_lens.is_some() { - Some( - metadata - .context_lens - .as_ref() - .unwrap() - .to_device(device) - .unwrap(), - ) - } else { - metadata.context_lens.clone() - }; - let block_tables = if metadata.block_tables.is_some() { - Some( - metadata - .block_tables - .as_ref() - .unwrap() - 
.to_device(device) - .unwrap(), + self.prepare_decode(scheduled, device) + } + .unwrap(); + ( + *rank, + pipeline + .forward( + tokens, + &positions, + Some(&*cache_engine.get_kv_cache()), + &metadata, ) - } else { - metadata.block_tables.clone() - }; + .unwrap(), + ) + }) + .collect(); - &InputMetadata { - //for other rank, some tensors need to be moved - slot_mapping: metadata.slot_mapping.to_device(device).unwrap(), - context_lens, - block_tables, - kv_cache_dtype: metadata.kv_cache_dtype.clone(), - prompt_lens: metadata.prompt_lens.clone(), - ..metadata - } - }; - pipeline - .forward( - tokens.clone(), - &positions, - Some(&*cache_engine.get_kv_cache()), - metadata_, - ) - .unwrap() + #[cfg(not(feature = "nccl"))] + let vec_logits: HashMap = self + .pipelines + .iter() + .map(|(rank, (pipeline, cache_engine))| { + let device = pipeline.device(); + let PreparedInputs { + tokens, + positions, + metadata, + } = if seqs.values().nth(0).unwrap().deref().is_prompt() { + self.prepare_prompt(scheduled, device) + } else { + self.prepare_decode(scheduled, device) + } + .unwrap(); + ( + *rank, + pipeline + .forward( + tokens, + &positions, + Some(&*cache_engine.get_kv_cache()), + &metadata, + ) + .unwrap(), + ) }) .collect(); + let pipeline = self.get_mut_pipeline(0).unwrap().0.as_mut(); - let results = pipeline.sample(&vec_logits[0], scheduled).unwrap(); + let results = pipeline.sample(&vec_logits[&0], scheduled).unwrap(); for (result_, group) in zip(results, scheduled) { match result_ { @@ -419,7 +415,7 @@ impl LLMEngine { fn prepare_prompt( &self, groups: &VecDeque>, - rank: usize, + device: &Device, ) -> Result { let mut prompt_lens = Vec::new(); let mut input_tokens = Vec::new(); @@ -484,7 +480,6 @@ impl LLMEngine { slot_mappings.push(slot_mapping); } } - let device = self.get_pipeline(rank).unwrap().0.device(); let max_prompt_len = prompt_lens.iter().max().unwrap(); let input_tokens = _make_tensor_with_pad( @@ -517,7 +512,7 @@ impl LLMEngine { fn prepare_decode( &self, groups: &VecDeque>, - rank: usize, + device: &Device, ) -> Result { let mut input_tokens = Vec::new(); let mut input_positions = Vec::new(); @@ -573,7 +568,6 @@ impl LLMEngine { } } } - let device = self.get_pipeline(rank).unwrap().0.device(); let input_tokens = _make_tensor_with_pad( input_tokens diff --git a/src/openai/pipelines/mod.rs b/src/openai/pipelines/mod.rs index 7793bf3..a4efd5f 100644 --- a/src/openai/pipelines/mod.rs +++ b/src/openai/pipelines/mod.rs @@ -1,59 +1,16 @@ +use super::responses::APIError; use crate::openai::sampling_params::Logprobs; -use candle_core::{DType, Device, Tensor, WithDType}; +use crate::try_api; +use candle_core::{Device, Tensor, WithDType}; use dirs; use either::Either; -use std::{env, fs, path::PathBuf, sync::Arc}; - -use crate::{paged_attention::input_metadata::InputMetadata, try_api}; - -use super::{conversation::Conversation, models::Config, responses::APIError, PipelineConfig}; -use candle_examples::token_output_stream::TokenOutputStream; +use std::{env, fs}; /// The LLMEngine is effectively a wrapper around a ModulePipeline. It contains a Scheduler and a CacheEngine /// which are used to scheduler and manage the cache during generation requests, respectively. 
pub mod llm_engine; pub mod pipeline; -use crate::scheduler::sequence::SequenceGroup; type TokenOrFinishReason = Either; -#[cfg(feature = "nccl")] -pub use cudarc::nccl::safe::Comm; -use std::collections::VecDeque; -pub use std::rc::Rc; - -pub trait ModulePipeline: Send + Sync { - fn forward( - &mut self, - input_tokens: Tensor, - input_positions: &[Vec], - kv_cache: Option<&Vec<(Tensor, Tensor)>>, - input_metadata: &InputMetadata, - ) -> Result; - - fn sample( - &mut self, - logits: &Tensor, - groups: &VecDeque>, - ) -> Result, APIError>; - - fn sample_batch( - &mut self, - logits: &Tensor, - groups: &VecDeque>, - ) -> Result, APIError>; - - fn name(&self) -> &str; - - fn tokenizer(&self) -> &TokenOutputStream; - - fn get_conversation(&mut self, with_history: bool) -> &mut dyn Conversation; - - fn get_model_config(&self) -> Config; - - fn get_dtype(&self) -> DType; - - fn device(&self) -> &Device; - - fn reset_decoder(&mut self) -> Option; -} +use crate::openai::pipelines::pipeline::DefaultPipeline; fn _make_tensor_with_pad( x: Vec>, @@ -98,28 +55,3 @@ pub(crate) fn get_token( } }) } - -pub trait ModelPaths { - fn get_weight_filenames(&self) -> &Vec; - fn get_config_filename(&self) -> &PathBuf; - fn get_tokenizer_filename(&self) -> &PathBuf; -} - -pub trait ModelLoader { - fn download_model( - &self, - model_id: String, - revision: Option, - hf_token: Option, - hf_token_path: Option, - ) -> Result, APIError>; - - fn load_model( - &self, - paths: &Box, - dtype: DType, - quant: &Option, - device: Device, - #[cfg(feature = "nccl")] comm: Option>, - ) -> Result<(Box, PipelineConfig), APIError>; -} diff --git a/src/openai/pipelines/pipeline.rs b/src/openai/pipelines/pipeline.rs index a4ed7f1..e3caaad 100644 --- a/src/openai/pipelines/pipeline.rs +++ b/src/openai/pipelines/pipeline.rs @@ -1,4 +1,4 @@ -use super::{get_token, ModelLoader, ModelPaths, ModulePipeline, TokenOrFinishReason}; +use super::{get_token, TokenOrFinishReason}; use crate::openai::logits_processor::{LogitsProcessor, Sampling}; use crate::openai::models::TokenID; use crate::openai::sampling_params::{Logprobs, TopLogprob}; @@ -30,12 +30,14 @@ use crate::{ paged_attention::input_metadata::InputMetadata, try_api, SpecificConfig, }; + +#[cfg(feature = "nccl")] +use crate::openai::models::llama_multi::LlamaMulti; + use candle_core::quantized::gguf_file; use candle_core::{DType, Device, IndexOp, Tensor}; use candle_examples::token_output_stream::TokenOutputStream; use candle_nn::VarBuilder; -#[cfg(feature = "nccl")] -pub use cudarc::nccl::safe::Comm; use either::Either; use either::Either::{Left, Right}; use hf_hub::{api::sync::ApiBuilder, Repo, RepoType}; @@ -59,6 +61,8 @@ enum LLMModel { StableLM(StableLM), LlamaGGUF(GGUFLLaMa), Phi3GGUF(GGUFPhi3), + #[cfg(feature = "nccl")] + LlamaMulti(LlamaMulti), } /// top-p, multinomial, and argmax sampling are implemented. Beam search is not implemented. pub struct DefaultPipeline { @@ -71,6 +75,7 @@ pub struct DefaultPipeline { dtype: DType, device: Device, stop_token_ids: Vec, + rank: usize, } pub struct DefaultLoader { @@ -78,21 +83,22 @@ pub struct DefaultLoader { name: String, } -pub struct DefaultModelPaths
<P> {
-    pub tokenizer_filename: P,
-    pub config_filename: P,
-    pub filenames: Vec<P>
, +#[derive(Debug, Clone)] +pub struct DefaultModelPaths { + pub tokenizer_filename: PathBuf, + pub config_filename: PathBuf, + pub filenames: Vec, } -impl ModelPaths for DefaultModelPaths { - fn get_config_filename(&self) -> &PathBuf { - &self.config_filename +impl DefaultModelPaths { + fn get_config_filename(&self) -> PathBuf { + self.config_filename.clone() } - fn get_tokenizer_filename(&self) -> &PathBuf { - &self.tokenizer_filename + fn get_tokenizer_filename(&self) -> PathBuf { + self.tokenizer_filename.clone() } - fn get_weight_filenames(&self) -> &Vec { - &self.filenames + fn get_weight_filenames(&self) -> Vec { + self.filenames.clone() } } @@ -102,14 +108,14 @@ impl DefaultLoader { } } -impl ModelLoader for DefaultLoader { - fn download_model( +impl DefaultLoader { + pub fn download_model( &self, model_id: String, revision: Option, hf_token: Option, hf_token_path: Option, - ) -> Result, APIError> { + ) -> Result { let api = try_api!(ApiBuilder::new() .with_progress(true) .with_token(Some(get_token(hf_token, hf_token_path)?)) @@ -132,27 +138,26 @@ impl ModelLoader for DefaultLoader { filenames.push(filename); } - Ok(Box::new(DefaultModelPaths { + Ok(DefaultModelPaths { tokenizer_filename, config_filename, filenames, - })) + }) } - fn load_model( + pub async fn load_model( &self, - paths: &Box, + paths: DefaultModelPaths, dtype: DType, quant: &Option, - device: Device, - #[cfg(feature = "nccl")] comm: Option>, - ) -> Result<(Box, PipelineConfig), APIError> { + device_ids: Vec, + ) -> Result<(Vec>, PipelineConfig), APIError> { let specific_args = self.config.clone(); - let mut stop_token_ids = Vec::::new(); - let (model, config, tokenizer, sep_style) = if quant.is_some() + let (models, devices, config, sep_style) = if quant.is_some() && matches!(quant.as_ref().unwrap().as_str(), "ggml" | "gguf") { + let device = crate::new_device(device_ids[0]).unwrap(); let path = paths.get_weight_filenames()[0].clone(); println!( "Loading quantized {} model from file {}", @@ -187,11 +192,7 @@ impl ModelLoader for DefaultLoader { } _ => panic!("Model not supported!"), }; - let tokenizer_ = Tokenizer::from_file(paths.get_tokenizer_filename()) - .map_err(|x| APIError::new(x.to_string()))?; - let tokenizer = - candle_examples::token_output_stream::TokenOutputStream::new(tokenizer_); - (model, config.to_owned(), tokenizer, sep_style) + (vec![model], vec![device], config.to_owned(), sep_style) } else { let config = match self.name.as_str() { "llama" | "llama3" => { @@ -249,81 +250,126 @@ impl ModelLoader for DefaultLoader { println!("Model {:?}", config); println!("Loading {} model.", self.name); + #[cfg(feature = "nccl")] + let (models, devices, sep_style) = if device_ids.len() > 1 { + use cudarc::nccl::safe::{Comm, Id}; + let id = Id::new().unwrap(); + let results: Vec<_> = device_ids + .par_iter() + .enumerate() + .map(|(rank, dev_id)| { + println!( + "Loading partial model on device rank {} (ordinal {})", + rank, *dev_id + ); + let pathes: Vec = paths.get_weight_filenames(); + let device = crate::new_device(*dev_id).unwrap(); + let comm = Rc::new( + Comm::from_rank( + device.as_cuda_device().unwrap().cuda_device(), + rank, + device_ids.len(), + id, + ) + .unwrap(), + ); + let vb = unsafe { + candle_nn::var_builder::ShardedSafeTensors::var_builder( + &pathes, dtype, &device, + ) + .unwrap() + }; + match self.name.as_str() { + "llama" | "llama3" => Ok(( + device.clone(), + LLMModel::LlamaMulti(try_api!(LlamaMulti::load( + vb, &config, dtype, &device, comm + ))), + )), + _ => panic!("Model not 
supported!"), + } + }) + .collect(); + + // Separate devices and models from the results + let mut devices = Vec::new(); + let mut models = Vec::new(); + for result in results { + match result { + Ok((device, model)) => { + devices.push(device); + models.push(model); + } + Err(e) => { + return Err(e.into()); + } + } + } - let vb = match unsafe { - VarBuilder::from_mmaped_safetensors(paths.get_weight_filenames(), dtype, &device) - } { - Ok(vb_) => vb_, - _ => panic!("Load model weights failed!"), + (models, devices, SeparatorStyle::Llama3) + } else { + panic!("You've enabled nccl feature for multi-gpu inference but only one device was given!"); }; - let (model, sep_style) = match self.name.as_str() { - "llama" => ( - LLMModel::Llama(try_api!(Llama::load(vb, &config, dtype, &device, comm))), - SeparatorStyle::Llama, - ), - "llama3" => ( - LLMModel::Llama(try_api!(Llama::load(vb, &config, dtype, &device, comm))), - SeparatorStyle::Llama3, - ), - "phi2" => ( - LLMModel::Phi2(try_api!(Phi2::new(vb, &config, dtype, &device))), - SeparatorStyle::Phi, - ), - "phi3" => ( - LLMModel::Phi3(try_api!(Phi::new(vb, &config, dtype, &device))), - SeparatorStyle::Phi, - ), - "qwen2" => ( - LLMModel::Qwen2(try_api!(Qwen2::new(vb, &config, dtype, &device))), - SeparatorStyle::Qwen2, - ), - "gemma" => ( - LLMModel::Gemma(try_api!(Gemma::new(vb, &config, dtype, &device))), - SeparatorStyle::Gemma, - ), - "mistral" => ( - LLMModel::Mistral(try_api!(Mistral::new(vb, &config, dtype, &device))), - SeparatorStyle::Mistral, - ), - "yi" => ( - LLMModel::Yi(try_api!(Yi::new(vb, &config, dtype, &device))), - SeparatorStyle::Yi, - ), - "stablelm" => ( - LLMModel::StableLM(try_api!(StableLM::new(vb, &config, dtype, &device))), - SeparatorStyle::StableLM, - ), - _ => panic!("Model not supported!"), - }; - match &config.eos_token_id { - //eos_token defined in the config - TokenID(Either::Left(eos_token)) => { - if let Some(tk) = eos_token { - stop_token_ids.push(*tk); - } - } - TokenID(Either::Right(eos_token_list)) => { - if let Some(tks) = eos_token_list { - stop_token_ids.extend(tks) - } - } - } + #[cfg(not(feature = "nccl"))] + let (models, devices, sep_style) = if device_ids.len() < 2 { + let device = crate::new_device(device_ids[0]).unwrap(); + let vb = match unsafe { + VarBuilder::from_mmaped_safetensors( + &paths.get_weight_filenames(), + dtype, + &device, + ) + } { + Ok(vb_) => vb_, + _ => panic!("Load model weights failed!"), + }; - let tokenizer_ = Tokenizer::from_file(paths.get_tokenizer_filename()) - .map_err(|x| APIError::new(x.to_string()))?; - let tokenizer = - candle_examples::token_output_stream::TokenOutputStream::new(tokenizer_); + let (model, sep) = match self.name.as_str() { + "llama" => ( + LLMModel::Llama(try_api!(Llama::load(vb, &config, dtype, &device))), + SeparatorStyle::Llama, + ), + "llama3" => ( + LLMModel::Llama(try_api!(Llama::load(vb, &config, dtype, &device))), + SeparatorStyle::Llama3, + ), + "phi2" => ( + LLMModel::Phi2(try_api!(Phi2::new(vb, &config, dtype, &device))), + SeparatorStyle::Phi, + ), + "phi3" => ( + LLMModel::Phi3(try_api!(Phi::new(vb, &config, dtype, &device))), + SeparatorStyle::Phi, + ), + "qwen2" => ( + LLMModel::Qwen2(try_api!(Qwen2::new(vb, &config, dtype, &device))), + SeparatorStyle::Qwen2, + ), + "gemma" => ( + LLMModel::Gemma(try_api!(Gemma::new(vb, &config, dtype, &device))), + SeparatorStyle::Gemma, + ), + "mistral" => ( + LLMModel::Mistral(try_api!(Mistral::new(vb, &config, dtype, &device))), + SeparatorStyle::Mistral, + ), + "yi" => ( + 
LLMModel::Yi(try_api!(Yi::new(vb, &config, dtype, &device))), + SeparatorStyle::Yi, + ), + "stablelm" => ( + LLMModel::StableLM(try_api!(StableLM::new(vb, &config, dtype, &device))), + SeparatorStyle::StableLM, + ), + _ => panic!("Model not supported!"), + }; + (vec![model], vec![device], sep) + } else { + panic!("You've provided multiple devices for inference but nccl feature is not enalbed!"); + }; - //custom stop tokens - if let Some(custom_stop) = &config.custom_stop_tokens { - for stop in custom_stop { - if let Some(token) = tokenizer.get_token(stop) { - stop_token_ids.push(token) - }; - } - } - (model, config, tokenizer, sep_style) + (models, devices, config, sep_style) }; println!("Done loading."); @@ -343,74 +389,106 @@ impl ModelLoader for DefaultLoader { }; println!("{:?}", pipeline_config); - - if stop_token_ids.is_empty() { - //if no eos_token defined in the config, use default - if let Some(token) = tokenizer.get_token("<|endoftext|>") { - stop_token_ids.push(token); - } - if let Some(token) = tokenizer.get_token("<|end|>") { - stop_token_ids.push(token); - } else if stop_token_ids.is_empty() { - let token = tokenizer.tokenizer().token_to_id(EOS_TOKEN).unwrap_or(0); - stop_token_ids.push(token); - } - } - println!("{:?}", specific_args); - let logits_processor = { - let temperature = f64::from(pipeline_config.temperature); - let sampling = if temperature <= 0. { - Sampling::ArgMax - } else { - match (specific_args.top_k, specific_args.top_p) { - (None, None) => Sampling::All { temperature }, - (Some(k), None) => Sampling::TopK { k, temperature }, - (None, Some(p)) => Sampling::TopP { p, temperature }, - (Some(k), Some(p)) => Sampling::TopKThenTopP { k, p, temperature }, + let pipelines = models + .into_iter() + .enumerate() + .map(|(rank, model)| { + let logits_processor = { + let temperature = f64::from(pipeline_config.temperature); + let sampling = if temperature <= 0. 
{ + Sampling::ArgMax + } else { + match (specific_args.top_k, specific_args.top_p) { + (None, None) => Sampling::All { temperature }, + (Some(k), None) => Sampling::TopK { k, temperature }, + (None, Some(p)) => Sampling::TopP { p, temperature }, + (Some(k), Some(p)) => Sampling::TopKThenTopP { k, p, temperature }, + } + }; + LogitsProcessor::from_sampling(SAMPLING_SEED, sampling) + }; + let tokenizer_ = Tokenizer::from_file(paths.get_tokenizer_filename()) + .map_err(|x| APIError::new(x.to_string())) + .unwrap(); + let tokenizer = + candle_examples::token_output_stream::TokenOutputStream::new(tokenizer_); + + let mut stop_token_ids = Vec::::new(); + match &config.eos_token_id { + //eos_token defined in the config + TokenID(Either::Left(eos_token)) => { + if let Some(tk) = eos_token { + stop_token_ids.push(*tk); + } + } + TokenID(Either::Right(eos_token_list)) => { + if let Some(tks) = eos_token_list { + stop_token_ids.extend(tks) + } + } + } + //custom stop tokens + if let Some(custom_stop) = &config.custom_stop_tokens { + for stop in custom_stop { + if let Some(token) = tokenizer.get_token(stop) { + stop_token_ids.push(token) + }; + } } - }; - LogitsProcessor::from_sampling(SAMPLING_SEED, sampling) - }; - Ok(( - Box::new(DefaultPipeline { - model, - args: specific_args, - tokenizer, - logits_processor, - conversation: DefaultConversation::new( - self.name.to_string(), - "[INST] <>\n{}\n<>\n\n [/INST]".to_string(), - Vec::default(), - 0, - sep_style, - "".to_string(), - stop_token_ids.clone(), - ("user".to_string(), "assistant".to_string()), - DefaultConversationSeparators { - sep: " ".to_string(), - sep2: Some(" ".to_string()), - }, - ), - name: self.name.clone(), - dtype, - device: device.clone(), - stop_token_ids, - }), - pipeline_config, - )) + if stop_token_ids.is_empty() { + //if no eos_token defined in the config, use default + if let Some(token) = tokenizer.get_token("<|endoftext|>") { + stop_token_ids.push(token); + } + if let Some(token) = tokenizer.get_token("<|end|>") { + stop_token_ids.push(token); + } else if stop_token_ids.is_empty() { + let token = tokenizer.tokenizer().token_to_id(EOS_TOKEN).unwrap_or(0); + stop_token_ids.push(token); + } + } + Box::new(DefaultPipeline { + model, + args: specific_args.clone(), + tokenizer, + logits_processor, + conversation: DefaultConversation::new( + self.name.to_string(), + "[INST] <>\n{}\n<>\n\n [/INST]".to_string(), + Vec::default(), + 0, + sep_style.clone(), + "".to_string(), + stop_token_ids.clone(), + ("user".to_string(), "assistant".to_string()), + DefaultConversationSeparators { + sep: " ".to_string(), + sep2: Some(" ".to_string()), + }, + ), + name: self.name.clone(), + dtype, + device: devices[rank].clone(), + stop_token_ids, + rank, + }) + }) + .collect(); + + Ok((pipelines, pipeline_config)) } } -impl ModulePipeline for DefaultPipeline { - fn forward( - &mut self, +impl DefaultPipeline { + pub fn forward( + &self, input_tokens: Tensor, input_positions: &[Vec], kv_cache: Option<&Vec<(Tensor, Tensor)>>, - mut input_metadata: &InputMetadata, + input_metadata: &InputMetadata, ) -> Result { let input_tokens = if input_tokens.shape().dims().len() < 2 { input_tokens @@ -420,91 +498,45 @@ impl ModulePipeline for DefaultPipeline { input_tokens }; - match &mut self.model { + match &self.model { LLMModel::Llama(llama) => llama - .forward( - &input_tokens, - input_positions, - kv_cache, - &mut input_metadata, - ) + .forward(&input_tokens, input_positions, kv_cache, &input_metadata) + .map_err(APIError::from), + #[cfg(feature = "nccl")] + 
LLMModel::LlamaMulti(llama) => llama + .forward(&input_tokens, input_positions, kv_cache, &input_metadata) .map_err(APIError::from), LLMModel::Phi2(phi) => phi - .forward( - &input_tokens, - input_positions, - kv_cache, - &mut input_metadata, - ) + .forward(&input_tokens, input_positions, kv_cache, &input_metadata) .map_err(APIError::from), LLMModel::Phi3(phi) => phi - .forward( - &input_tokens, - input_positions, - kv_cache, - &mut input_metadata, - ) + .forward(&input_tokens, input_positions, kv_cache, &input_metadata) .map_err(APIError::from), LLMModel::Qwen2(qwen2) => qwen2 - .forward( - &input_tokens, - input_positions, - kv_cache, - &mut input_metadata, - ) + .forward(&input_tokens, input_positions, kv_cache, &input_metadata) .map_err(APIError::from), LLMModel::Gemma(gemma) => gemma - .forward( - &input_tokens, - input_positions, - kv_cache, - &mut input_metadata, - ) + .forward(&input_tokens, input_positions, kv_cache, &input_metadata) .map_err(APIError::from), LLMModel::Mistral(mistral) => mistral - .forward( - &input_tokens, - input_positions, - kv_cache, - &mut input_metadata, - ) + .forward(&input_tokens, input_positions, kv_cache, &input_metadata) .map_err(APIError::from), LLMModel::Yi(yi) => yi - .forward( - &input_tokens, - input_positions, - kv_cache, - &mut input_metadata, - ) + .forward(&input_tokens, input_positions, kv_cache, &input_metadata) .map_err(APIError::from), LLMModel::StableLM(stablelm) => stablelm - .forward( - &input_tokens, - input_positions, - kv_cache, - &mut input_metadata, - ) + .forward(&input_tokens, input_positions, kv_cache, &input_metadata) .map_err(APIError::from), LLMModel::Phi3GGUF(phi3) => phi3 - .forward( - &input_tokens, - input_positions, - kv_cache, - &mut input_metadata, - ) + .forward(&input_tokens, input_positions, kv_cache, &input_metadata) .map_err(APIError::from), LLMModel::LlamaGGUF(llama) => llama - .forward( - &input_tokens, - input_positions, - kv_cache, - &mut input_metadata, - ) + .forward(&input_tokens, input_positions, kv_cache, &input_metadata) .map_err(APIError::from), } } - fn sample( + pub fn sample( &mut self, logits: &Tensor, groups: &VecDeque>, @@ -606,7 +638,7 @@ impl ModulePipeline for DefaultPipeline { Ok(result) } - fn sample_batch( + pub fn sample_batch( &mut self, logits: &Tensor, groups: &VecDeque>, @@ -662,24 +694,26 @@ impl ModulePipeline for DefaultPipeline { Ok(result) } - fn name(&self) -> &str { + pub fn name(&self) -> &str { &self.name } - fn tokenizer(&self) -> &TokenOutputStream { + pub fn tokenizer(&self) -> &TokenOutputStream { &self.tokenizer } - fn get_conversation(&mut self, with_history: bool) -> &mut dyn Conversation { + pub fn get_conversation(&mut self, with_history: bool) -> &mut dyn Conversation { if !with_history { self.conversation.clear_message(); } &mut self.conversation } - fn get_model_config(&self) -> Config { + pub fn get_model_config(&self) -> Config { match &self.model { LLMModel::Llama(llama) => llama.get_config().clone(), + #[cfg(feature = "nccl")] + LLMModel::LlamaMulti(llama) => llama.get_config().clone(), LLMModel::Phi2(phi) => phi.get_config().clone(), LLMModel::Phi3(phi) => phi.get_config().clone(), LLMModel::Qwen2(qwen2) => qwen2.get_config().clone(), @@ -692,19 +726,23 @@ impl ModulePipeline for DefaultPipeline { } } - fn get_dtype(&self) -> DType { + pub fn get_dtype(&self) -> DType { self.dtype } - fn device(&self) -> &Device { + pub fn device(&self) -> &Device { &self.device } - fn reset_decoder(&mut self) -> Option { + pub fn reset_decoder(&mut self) -> Option { let ret = 
self.tokenizer.decode_rest().unwrap_or(None); self.tokenizer.clear(); ret } + + pub fn rank(&self) -> usize { + self.rank + } } unsafe impl Send for DefaultPipeline {} diff --git a/src/paged_attention/mod.rs b/src/paged_attention/mod.rs index eff3a1b..6daa4e6 100644 --- a/src/paged_attention/mod.rs +++ b/src/paged_attention/mod.rs @@ -59,13 +59,13 @@ impl PagedAttention { /// block_size] /// input_metadata: metadata for paged attention. pub fn forward( - &mut self, + &self, query: &Tensor, key: &Tensor, value: &Tensor, attention_mask: Option<&Tensor>, - mut key_cache: Option, - mut value_cache: Option, + key_cache: Option, + value_cache: Option, input_metadata: &InputMetadata, softcapping: Option, ) -> Result { @@ -146,8 +146,8 @@ impl PagedAttention { reshape_and_cache( &key, &value, - key_cache.as_mut().unwrap(), - value_cache.as_mut().unwrap(), + key_cache.as_ref().unwrap(), + value_cache.as_ref().unwrap(), &slot_mapping, )?; } diff --git a/src/scheduler/cache_engine.rs b/src/scheduler/cache_engine.rs index 7b7b717..dbef7c4 100644 --- a/src/scheduler/cache_engine.rs +++ b/src/scheduler/cache_engine.rs @@ -37,6 +37,7 @@ impl CacheConfig { pub type KVCache = (Tensor, Tensor); +#[derive(Debug)] pub struct CacheEngine { gpu_cache: Arc>>, cpu_cache: Vec, @@ -49,6 +50,7 @@ impl CacheEngine { cache_config: &CacheConfig, dtype: DType, device: &Device, + num_shards: usize, ) -> Result { Ok(Self { gpu_cache: Arc::new(Mutex::new(Self::allocate_gpu_cache( @@ -56,8 +58,15 @@ impl CacheEngine { cache_config, dtype, device, + num_shards, )?)), - cpu_cache: Self::allocate_cpu_cache(model_config, cache_config, dtype, device)?, + cpu_cache: Self::allocate_cpu_cache( + model_config, + cache_config, + dtype, + device, + num_shards, + )?, num_layers: model_config.num_hidden_layers, }) } @@ -75,13 +84,18 @@ impl CacheEngine { cache_config: &CacheConfig, dtype: DType, device: &Device, + num_shards: usize, ) -> Result, APIError> { assert!(cache_config.fully_init); - let key_block_shape = - Self::calculate_key_block_shape(model_config, dtype, cache_config.block_size); + let key_block_shape = Self::calculate_key_block_shape( + model_config, + dtype, + cache_config.block_size, + num_shards, + ); let value_block_shape = - Self::calculate_value_block_shape(model_config, cache_config.block_size); + Self::calculate_value_block_shape(model_config, cache_config.block_size, num_shards); let mut gpu_cache = Vec::new(); for _ in 0..model_config.num_hidden_layers { let key_blocks = try_api!(Tensor::zeros( @@ -115,13 +129,18 @@ impl CacheEngine { cache_config: &CacheConfig, dtype: DType, device: &Device, + num_shards: usize, ) -> Result, APIError> { assert!(cache_config.fully_init); - let key_block_shape = - Self::calculate_key_block_shape(model_config, dtype, cache_config.block_size); + let key_block_shape = Self::calculate_key_block_shape( + model_config, + dtype, + cache_config.block_size, + num_shards, + ); let value_block_shape = - Self::calculate_value_block_shape(model_config, cache_config.block_size); + Self::calculate_value_block_shape(model_config, cache_config.block_size, num_shards); let mut cpu_cache = Vec::new(); for _ in 0..model_config.num_hidden_layers { let key_blocks = try_api!(Tensor::zeros( @@ -156,11 +175,12 @@ impl CacheEngine { model_config: &Config, dtype: DType, block_size: usize, + num_shards: usize, ) -> (usize, usize, usize, usize) { let element_size = dtype.size_in_bytes(); let x = 16 / element_size; ( - model_config.num_key_value_heads, + model_config.num_key_value_heads / num_shards, 
             model_config.get_head_size() / x,
             block_size,
             x,
@@ -170,9 +190,10 @@ impl CacheEngine {
     fn calculate_value_block_shape(
         model_config: &Config,
         block_size: usize,
+        num_shards: usize,
     ) -> (usize, usize, usize) {
         (
-            model_config.num_key_value_heads,
+            model_config.num_key_value_heads / num_shards,
             model_config.get_head_size(),
             block_size,
         )

From 31e655cb4d324561e6b3cdcbcce634fa0f7b1024 Mon Sep 17 00:00:00 2001
From: Guoqing Bao
Date: Tue, 14 Jan 2025 08:49:24 +0000
Subject: [PATCH 4/6] Fix typo

---
 src/openai/pipelines/pipeline.rs | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/src/openai/pipelines/pipeline.rs b/src/openai/pipelines/pipeline.rs
index e3caaad..e37ea0a 100644
--- a/src/openai/pipelines/pipeline.rs
+++ b/src/openai/pipelines/pipeline.rs
@@ -262,7 +262,7 @@ impl DefaultLoader {
                 "Loading partial model on device rank {} (ordinal {})",
                 rank, *dev_id
             );
-            let pathes: Vec = paths.get_weight_filenames();
+            let paths: Vec = paths.get_weight_filenames();
             let device = crate::new_device(*dev_id).unwrap();
             let comm = Rc::new(
                 Comm::from_rank(
@@ -275,7 +275,7 @@ impl DefaultLoader {
             );
             let vb = unsafe {
                 candle_nn::var_builder::ShardedSafeTensors::var_builder(
-                    &pathes, dtype, &device,
+                    &paths, dtype, &device,
                 )
                 .unwrap()
             };
@@ -366,7 +366,7 @@ impl DefaultLoader {
             };
             (vec![model], vec![device], sep)
         } else {
-            panic!("You've provided multiple devices for inference but nccl feature is not enalbed!");
+            panic!("You've provided multiple devices for inference but nccl feature is not enabled!");
         };
 
         (models, devices, config, sep_style)

From d5d3afc1186b650e16b9dd9d8f58ac89f4c973d4 Mon Sep 17 00:00:00 2001
From: Guoqing Bao
Date: Tue, 14 Jan 2025 08:53:59 +0000
Subject: [PATCH 5/6] Typo fix

---
 README.md                          | 8 ++++----
 src/openai/pipelines/llm_engine.rs | 6 +++---
 2 files changed, 7 insertions(+), 7 deletions(-)

diff --git a/README.md b/README.md
index 4440de4..a9fd379 100644
--- a/README.md
+++ b/README.md
@@ -77,11 +77,11 @@ Run `Multi-GPU` inference with NCCL feature
 cargo run --release --features cuda,nccl -- --port 2000 --device-ids "0,1" --weight-path /home/Meta-Llama-3.1-8B-Instruct/ llama3 --temperature 0. --penalty 1.0
 ```
 
-If you encoutered problems under Multi-GPU setttings, you may:
+If you encountered problems under Multi-GPU settings, you may:
 ```shell
-export NCCL_P2P_LEVEL=LOC # use local devices (mutiple cards within a server, PCIE, etc.)
-export NCCL_P2P_DISABLE=1 # diable p2p cause this feature can cause illegal memory access in certain environments
-export NCCL_IB_DISABLE=1 # diable ibnet/infiniband (optional)
+export NCCL_P2P_LEVEL=LOC # use local devices (multiple cards within a server, PCIE, etc.)
+export NCCL_P2P_DISABLE=1 # disable p2p cause this feature can cause illegal memory access in certain environments
+export NCCL_IB_DISABLE=1 # disable ibnet/infiniband (optional)
 ```
 
 **Note:** quantized models are not supported yet under multi-gpu setting.
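
As a rough illustration of the troubleshooting hunk above (not part of this patch or of candle-vllm itself), a launcher could echo those NCCL variables before the communicator is created so a misconfigured value shows up in the logs; the helper below and its name `log_nccl_env` are hypothetical.

```rust
use std::env;

// Hypothetical helper (illustration only): print the NCCL-related environment
// variables mentioned in the README so a bad multi-GPU setting is easy to spot.
fn log_nccl_env() {
    for key in ["NCCL_P2P_LEVEL", "NCCL_P2P_DISABLE", "NCCL_IB_DISABLE"] {
        match env::var(key) {
            Ok(val) => println!("{key}={val}"),
            Err(_) => println!("{key} is unset (NCCL default applies)"),
        }
    }
}

fn main() {
    log_nccl_env();
}
```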
diff --git a/src/openai/pipelines/llm_engine.rs b/src/openai/pipelines/llm_engine.rs
index 1c10f84..68a959b 100644
--- a/src/openai/pipelines/llm_engine.rs
+++ b/src/openai/pipelines/llm_engine.rs
@@ -155,10 +155,10 @@ impl LLMEngine {
         finish_reason: Option,
     ) -> ChatCompletionChunk {
         let mut choices = Vec::new();
-        let pipline = self.get_mut_pipeline(0).unwrap().0.as_mut();
+        let pipeline = self.get_mut_pipeline(0).unwrap().0.as_mut();
         let choice = Choice {
             delta: ChoiceData {
-                role: pipline.get_conversation(true).get_roles().0.clone(),
+                role: pipeline.get_conversation(true).get_roles().0.clone(),
                 content,
             },
             finish_reason,
@@ -170,7 +170,7 @@ impl LLMEngine {
             id: request_id,
             choices,
             created,
-            model: pipline.name().to_string(),
+            model: pipeline.name().to_string(),
             object: "chat.completion.chunk",
             system_fingerprint: None,
         }

From b01103c4d68e913c339f40c5a0de5977b2842d1a Mon Sep 17 00:00:00 2001
From: Guoqing Bao
Date: Tue, 14 Jan 2025 08:58:06 +0000
Subject: [PATCH 6/6] Bump the project version to 0.1.1

---
 Cargo.lock | 2 +-
 Cargo.toml | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/Cargo.lock b/Cargo.lock
index a20ed3a..52e3eb7 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -474,7 +474,7 @@ dependencies = [
 
 [[package]]
 name = "candle-vllm"
-version = "0.1.0"
+version = "0.1.1"
 dependencies = [
  "accelerate-src",
  "anyhow",
diff --git a/Cargo.toml b/Cargo.toml
index e4b9d6c..99a9eb1 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -1,6 +1,6 @@
 [package]
 name = "candle-vllm"
-version = "0.1.0"
+version = "0.1.1"
 edition = "2021"
 
 # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
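
The cache-engine changes in patch 3 divide `num_key_value_heads` by the shard count, so each GPU only allocates KV-cache blocks for its own slice of the heads. The sketch below is a standalone illustration of that arithmetic, not the crate's code: the `ModelConfig` struct, its `head_size` field, and the example numbers (8 KV heads, head size 128, f16 cache, block size 16, 2 GPUs) are assumptions made for the example.

```rust
// Standalone sketch of the per-shard KV-cache block shapes from patch 3.
struct ModelConfig {
    num_key_value_heads: usize,
    head_size: usize,
}

fn key_block_shape(
    cfg: &ModelConfig,
    element_size: usize, // dtype size in bytes, e.g. 2 for f16
    block_size: usize,
    num_shards: usize,
) -> (usize, usize, usize, usize) {
    // Keys are laid out as [kv_heads / shards, head_size / x, block_size, x],
    // where x packs 16 bytes worth of elements together.
    let x = 16 / element_size;
    (
        cfg.num_key_value_heads / num_shards,
        cfg.head_size / x,
        block_size,
        x,
    )
}

fn value_block_shape(
    cfg: &ModelConfig,
    block_size: usize,
    num_shards: usize,
) -> (usize, usize, usize) {
    // Values keep the full head size: [kv_heads / shards, head_size, block_size].
    (cfg.num_key_value_heads / num_shards, cfg.head_size, block_size)
}

fn main() {
    let cfg = ModelConfig { num_key_value_heads: 8, head_size: 128 };
    println!("{:?}", key_block_shape(&cfg, 2, 16, 2)); // (4, 16, 16, 8)
    println!("{:?}", value_block_shape(&cfg, 16, 2)); // (4, 128, 16)
}
```

Only the first tuple element shrinks as `num_shards` grows; head size and block size stay per-shard invariant, which is why the engine can pass `num_shards` straight through to both shape helpers.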