From 743a8b249c73bed2a49356007a2a8f77ae1564ef Mon Sep 17 00:00:00 2001 From: Guoqing Bao Date: Wed, 3 Jul 2024 09:54:56 +0800 Subject: [PATCH] Unified pipeline for models & support phi3 model (#45) * Optional logprobs & fix llama eos/stop token * Cargo fmt * Mention other options for chat completion request * Configurable kvcache & fix repeat chat history * Improve readability * Instructions for ChatUI & add demo chat video * Optimization for decoding stage & try to fix blocktable issue * Support stream response for chat completion * Update ReadMe & demo video * Reduce demo video size * Fix stream generation hang in release mode * Reduce the buffer size & update ReadMe * Fix LLaMa2 prompt instruction (for long conversation) * Cargo fmt * Padding to avoid block allocation issue & revision for prompt instruction * Unfied pipeline for models & support phi3 model * Fix padding strategy * Cargo fmt * Update ReadMe for supported models --- README.md | 24 +- src/lib.rs | 29 +- src/main.rs | 12 +- .../conversation/default_conversation.rs | 24 + src/openai/models/llama.rs | 106 +---- src/openai/models/mod.rs | 31 +- src/openai/models/phi3.rs | 427 ++++++++++++++++++ src/openai/openai_server.rs | 11 +- src/openai/pipelines/llm_engine.rs | 20 +- src/openai/pipelines/mod.rs | 6 +- .../pipelines/{llama.rs => pipeline.rs} | 124 +++-- src/paged_attention/mod.rs | 2 +- src/scheduler/cache_engine.rs | 30 +- tests/tests.rs | 9 +- 14 files changed, 652 insertions(+), 203 deletions(-) create mode 100644 src/openai/models/phi3.rs rename src/openai/pipelines/{llama.rs => pipeline.rs} (72%) diff --git a/README.md b/README.md index 23831b5..7e9e8c3 100644 --- a/README.md +++ b/README.md @@ -13,11 +13,25 @@ Efficient, easy-to-use platform for inference and serving local LLMs including a - Efficient management of key-value cache with PagedAttention. - Continuous batching. -### Pipelines -- Llama - - 7b - - 13b - - 70b +## Develop Status + +Currently, candle-vllm supports chat serving for the following models. + +| Model ID | Model Type | Supported | Speed (A100, BF16) +|--|--|--|--| +| #1 | **LLAMA/LLAMA2/LLaMa3** |✅|71 tks/s (7B)| +| #2 | Mistral |TBD|TBD| +| #3 | Phi (v1, v1.5, v2) |TBD|TBD| +| #4 | **Phi-3 (3.8B, 7B)** |✅|99 tks/s (3.8B)| +| #5 | Yi |TBD|TBD| +| #6 | StableLM |TBD|TBD| +| #7 | BigCode/StarCode |TBD|TBD| +| #8 | ChatGLM |TBD|TBD| +| #9 | QWen |TBD|TBD| +| #10 | Google Gemma |TBD|TBD| +| #11 | Blip-large (Multimodal) |TBD|TBD| +| #12 | Moondream-2 (Multimodal LLM) |TBD|TBD| + ## Demo Chat with candle-vllm (71 tokens/s, LLaMa2 7B, bf16, on A100) diff --git a/src/lib.rs b/src/lib.rs index 5675719..45831cb 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -3,7 +3,7 @@ use candle::Result; use candle_core as candle; use clap::Subcommand; use openai::pipelines::{ - llama::{LlamaLoader, LlamaSpecificConfig}, + pipeline::{DefaultLoader, SpecificConfig}, ModelLoader, }; @@ -29,6 +29,13 @@ pub enum ModelSelected { #[arg(long)] repeat_last_n: usize, }, + + /// Select the phi3 3.8b model. 
+ Phi3 { + /// Control the application of repeat penalty for the last n tokens + #[arg(long)] + repeat_last_n: usize, + }, } impl ToString for ModelSelected { @@ -37,6 +44,7 @@ impl ToString for ModelSelected { ModelSelected::Llama7b { repeat_last_n: _ } => "llama7b".to_string(), ModelSelected::Llama13b { repeat_last_n: _ } => "llama13b".to_string(), ModelSelected::Llama70b { repeat_last_n: _ } => "llama70b".to_string(), + ModelSelected::Phi3 { repeat_last_n: _ } => "phi3".to_string(), } } } @@ -44,26 +52,33 @@ impl ToString for ModelSelected { pub fn get_model_loader<'a>(selected_model: ModelSelected) -> (Box>, String) { match selected_model { ModelSelected::Llama7b { repeat_last_n } => ( - Box::new(LlamaLoader::new( - LlamaSpecificConfig::new(repeat_last_n), + Box::new(DefaultLoader::new( + SpecificConfig::new(repeat_last_n), "llama7b".to_string(), )), "meta-llama/Llama-2-7b-chat-hf".to_string(), ), ModelSelected::Llama13b { repeat_last_n } => ( - Box::new(LlamaLoader::new( - LlamaSpecificConfig::new(repeat_last_n), + Box::new(DefaultLoader::new( + SpecificConfig::new(repeat_last_n), "llama13b".to_string(), )), "meta-llama/Llama-2-13b-chat-hf".to_string(), ), ModelSelected::Llama70b { repeat_last_n } => ( - Box::new(LlamaLoader::new( - LlamaSpecificConfig::new(repeat_last_n), + Box::new(DefaultLoader::new( + SpecificConfig::new(repeat_last_n), "llama70b".to_string(), )), "meta-llama/Llama-2-70b-chat-hf".to_string(), ), + ModelSelected::Phi3 { repeat_last_n } => ( + Box::new(DefaultLoader::new( + SpecificConfig::new(repeat_last_n), + "phi3".to_string(), + )), + "microsoft/Phi-3-mini-4k-instruct".to_string(), + ), } } diff --git a/src/main.rs b/src/main.rs index 219ae22..50fd218 100644 --- a/src/main.rs +++ b/src/main.rs @@ -4,8 +4,8 @@ use actix_web::{App, HttpServer}; use candle_core::{DType, Device}; use candle_examples; use candle_vllm::openai::openai_server::chat_completions; -use candle_vllm::openai::pipelines::llama::LlamaModelPaths; use candle_vllm::openai::pipelines::llm_engine::LLMEngine; +use candle_vllm::openai::pipelines::pipeline::DefaultModelPaths; use candle_vllm::openai::responses::APIError; use candle_vllm::openai::OpenAIServerData; use candle_vllm::scheduler::cache_engine::CacheConfig; @@ -77,7 +77,7 @@ async fn main() -> Result<(), APIError> { let (loader, model_id) = get_model_loader(args.command); let paths = match &args.weight_path { - Some(path) => Box::new(LlamaModelPaths { + Some(path) => Box::new(DefaultModelPaths { tokenizer_filename: (path.to_owned() + "tokenizer.json").into(), config_filename: (path.to_owned() + "config.json").into(), filenames: hub_load_local_safetensors(path, "model.safetensors.index.json").unwrap(), @@ -100,16 +100,16 @@ async fn main() -> Result<(), APIError> { let num_gpu_blocks = args.kvcache_mem_gpu * SIZE_IN_MB / dsize / args.block_size - / config.get_num_kv_heads() + / config.num_key_value_heads / config.get_head_size() - / config.get_num_hidden_layers() + / config.num_hidden_layers / 2; let num_cpu_blocks = args.kvcache_mem_cpu * SIZE_IN_MB / dsize / args.block_size - / config.get_num_kv_heads() + / config.num_key_value_heads / config.get_head_size() - / config.get_num_hidden_layers() + / config.num_hidden_layers / 2; let cache_config = CacheConfig { block_size: args.block_size, diff --git a/src/openai/conversation/default_conversation.rs b/src/openai/conversation/default_conversation.rs index f4e17b4..c43b318 100644 --- a/src/openai/conversation/default_conversation.rs +++ b/src/openai/conversation/default_conversation.rs @@ -17,6 
+17,7 @@ pub enum SeparatorStyle { NoColonTwo, AddNewLineSingle, Llama2, + Phi, ChatGLM, ChatML, ChatIntern, @@ -242,6 +243,29 @@ impl Conversation for DefaultConversation { accum } + SeparatorStyle::Phi => { + let mut accum = "".to_string(); + for (i, message) in self.messages.iter().enumerate() { + let Message((_role, message)) = message; + if _role.clone() == self.roles.0 { + //user message + if let Some(message) = message { + accum += &format!("<|user|> {message}<|end|>"); + } else { + accum += &format!("<|user|> <|end|"); + } + } else if _role.clone() == self.roles.1 { + //assistant message + if let Some(message) = message { + accum += &format!("<|assistant|>{message}<|end|>"); + } + } else if i == 0 && !system_prompt.is_empty() { + accum += &system_prompt; + } + } + accum + } + SeparatorStyle::ChatGLM => { let round_add_n = if self.name == "chatglm2" { 1 } else { 0 }; diff --git a/src/openai/models/llama.rs b/src/openai/models/llama.rs index 45e5a2b..ae41071 100644 --- a/src/openai/models/llama.rs +++ b/src/openai/models/llama.rs @@ -1,4 +1,4 @@ -use super::ConfigLike; +use super::Config; use crate::paged_attention::input_metadata::InputMetadata; use crate::paged_attention::PagedAttention; use candle::{DType, Device, IndexOp, Result, Tensor, D}; @@ -23,37 +23,10 @@ pub struct LlamaConfig { pub eos_token_id: Option, } -impl LlamaConfig { - pub fn num_key_value_heads(&self) -> usize { - self.num_key_value_heads.unwrap_or(self.num_attention_heads) - } -} - fn default_rope() -> f32 { 10_000.0 } -impl ConfigLike for LlamaConfig { - fn get_num_kv_heads(&self) -> usize { - self.num_key_value_heads.unwrap_or(self.num_attention_heads) - } - fn get_hidden_size(&self) -> usize { - self.hidden_size - } - fn get_num_hidden_layers(&self) -> usize { - self.num_hidden_layers - } - fn get_num_attention_heads(&self) -> usize { - self.num_attention_heads - } - fn get_vocab_size(&self) -> usize { - self.vocab_size - } - fn get_sliding_window(&self) -> Option { - None - } -} - impl LlamaConfig { pub fn into_config(self, use_flash_attn: bool) -> Config { Config { @@ -68,80 +41,13 @@ impl LlamaConfig { use_flash_attn, bos_token_id: self.bos_token_id, eos_token_id: self.eos_token_id, + max_seq_len: MAX_SEQ_LEN, + sliding_window: None, + hidden_act: None, } } } -#[derive(Debug, Clone)] -pub struct Config { - pub hidden_size: usize, - pub intermediate_size: usize, - pub vocab_size: usize, - pub num_hidden_layers: usize, - pub num_attention_heads: usize, - pub num_key_value_heads: usize, - pub use_flash_attn: bool, - pub rms_norm_eps: f64, - pub rope_theta: f32, - pub bos_token_id: Option, - pub eos_token_id: Option, -} - -impl Config { - pub fn config_7b_v1(use_flash_attn: bool) -> Self { - Self { - hidden_size: 4096, - intermediate_size: 11008, - vocab_size: 32000, - num_hidden_layers: 32, - num_attention_heads: 32, - num_key_value_heads: 32, - use_flash_attn, - rms_norm_eps: 1e-6, - rope_theta: 10_000.0, - bos_token_id: None, - eos_token_id: None, - } - } - - pub fn config_7b_v2(use_flash_attn: bool) -> Self { - Self { - hidden_size: 4096, - intermediate_size: 11008, - vocab_size: 32000, - num_hidden_layers: 32, - num_attention_heads: 32, - num_key_value_heads: 32, - use_flash_attn, - rms_norm_eps: 1e-5, - rope_theta: 10_000.0, - bos_token_id: None, - eos_token_id: None, - } - } -} - -impl ConfigLike for Config { - fn get_num_kv_heads(&self) -> usize { - self.num_key_value_heads - } - fn get_hidden_size(&self) -> usize { - self.hidden_size - } - fn get_num_hidden_layers(&self) -> usize { - 
self.num_hidden_layers - } - fn get_num_attention_heads(&self) -> usize { - self.num_attention_heads - } - fn get_vocab_size(&self) -> usize { - self.vocab_size - } - fn get_sliding_window(&self) -> Option { - None - } -} - #[derive(Debug, Clone)] pub struct Cache { masks: HashMap, @@ -159,9 +65,9 @@ impl Cache { .map(|i| 1f32 / config.rope_theta.powf(i as f32 / n_elem as f32)) .collect(); let theta = Tensor::new(theta.as_slice(), device)?; - let idx_theta = Tensor::arange(0, MAX_SEQ_LEN as u32, device)? + let idx_theta = Tensor::arange(0, config.max_seq_len as u32, device)? .to_dtype(DType::F32)? - .reshape((MAX_SEQ_LEN, 1))? + .reshape((config.max_seq_len, 1))? .matmul(&theta.reshape((1, theta.elem_count()))?)?; let cos = idx_theta.cos()?.to_dtype(dtype)?; let sin = idx_theta.sin()?.to_dtype(dtype)?; diff --git a/src/openai/models/mod.rs b/src/openai/models/mod.rs index 545c9bc..a4ddf19 100644 --- a/src/openai/models/mod.rs +++ b/src/openai/models/mod.rs @@ -1,13 +1,26 @@ pub mod llama; +pub mod phi3; -pub trait ConfigLike { - fn get_num_kv_heads(&self) -> usize; - fn get_hidden_size(&self) -> usize; - fn get_num_hidden_layers(&self) -> usize; - fn get_num_attention_heads(&self) -> usize; - fn get_vocab_size(&self) -> usize; - fn get_sliding_window(&self) -> Option; - fn get_head_size(&self) -> usize { - self.get_hidden_size() / self.get_num_attention_heads() +#[derive(Debug, Clone)] +pub struct Config { + pub hidden_size: usize, + pub intermediate_size: usize, + pub vocab_size: usize, + pub num_hidden_layers: usize, + pub num_attention_heads: usize, + pub num_key_value_heads: usize, + pub use_flash_attn: bool, + pub rms_norm_eps: f64, + pub rope_theta: f32, + pub bos_token_id: Option, + pub eos_token_id: Option, + pub max_seq_len: usize, + pub sliding_window: Option, + pub hidden_act: Option, +} + +impl Config { + pub fn get_head_size(&self) -> usize { + self.hidden_size / self.num_attention_heads } } diff --git a/src/openai/models/phi3.rs b/src/openai/models/phi3.rs new file mode 100644 index 0000000..62421bd --- /dev/null +++ b/src/openai/models/phi3.rs @@ -0,0 +1,427 @@ +// This implementation is based on: +// https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/modeling_phi3.py +use super::Config; +use crate::paged_attention::input_metadata::InputMetadata; +use crate::paged_attention::PagedAttention; +use candle::{DType, Device, IndexOp, Module, Result, Tensor, D}; +use candle_core as candle; +use candle_nn::LayerNorm; +use candle_nn::VarBuilder; +use candle_transformers::models::with_tracing::{linear_no_bias as linear, Linear}; +use std::iter::zip; +use std::sync::Arc; + +// https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/config.json +#[derive(Debug, Clone, serde::Deserialize)] +pub struct PhiConfig { + pub vocab_size: usize, + pub hidden_act: candle_nn::Activation, + pub hidden_size: usize, + pub intermediate_size: usize, + pub num_hidden_layers: usize, + pub num_attention_heads: usize, + pub num_key_value_heads: usize, + pub rms_norm_eps: f64, + pub rope_theta: f64, + pub bos_token_id: Option, + pub eos_token_id: Option, + pub rope_scaling: Option, + pub max_position_embeddings: usize, + pub sliding_window: Option, +} + +impl PhiConfig { + pub fn into_config(self, use_flash_attn: bool) -> Config { + Config { + hidden_size: self.hidden_size, + intermediate_size: self.intermediate_size, + vocab_size: self.vocab_size, + num_hidden_layers: self.num_hidden_layers, + num_attention_heads: self.num_attention_heads, + num_key_value_heads: 
self.num_key_value_heads, + rms_norm_eps: self.rms_norm_eps, + rope_theta: self.rope_theta as f32, + use_flash_attn, + bos_token_id: self.bos_token_id, + eos_token_id: self.eos_token_id, + max_seq_len: self.max_position_embeddings, + sliding_window: self.sliding_window, + hidden_act: Some(self.hidden_act), + } + } +} + +pub struct RmsNorm { + //High-precision RmsNorm + norm: LayerNorm, + weight: Tensor, + eps: f64, +} + +impl RmsNorm { + pub fn new(size: usize, eps: f64, vb: VarBuilder) -> Result { + let weight = vb.get_with_hints(size, "weight", candle_nn::Init::Const(1.))?; + let weight_f32 = weight.to_dtype(DType::F32)?; + Ok(RmsNorm { + norm: LayerNorm::rms_norm(weight, eps), + weight: weight_f32, + eps, + }) + } +} + +impl Module for RmsNorm { + fn forward(&self, xs: &Tensor) -> Result { + if xs.is_contiguous() { + let dtype = xs.dtype(); + candle_nn::ops::rms_norm(&xs.to_dtype(DType::F32)?, &self.weight, self.eps as f32)? + .to_dtype(dtype) + } else { + self.norm.forward(xs) + } + } +} + +#[derive(Debug, Clone)] +struct RotaryEmbedding { + sin: Tensor, + cos: Tensor, + cos_sin: Tensor, +} + +impl RotaryEmbedding { + fn new(dtype: DType, cfg: &Config, dev: &Device) -> Result { + let dim = cfg.hidden_size / cfg.num_attention_heads; + let max_seq_len = cfg.max_seq_len; + let inv_freq: Vec<_> = (0..dim) + .step_by(2) + .map(|i| 1f32 / cfg.rope_theta.powf(i as f32 / dim as f32)) + .collect(); + let inv_freq_len = inv_freq.len(); + let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?.to_dtype(DType::F32)?; + let t = Tensor::arange(0u32, max_seq_len as u32, dev)? + .to_dtype(DType::F32)? + .reshape((max_seq_len, 1))?; + let freqs = t.matmul(&inv_freq)?; + let cos_sin = Tensor::cat(&[&freqs.cos()?, &freqs.sin()?], D::Minus1)?.contiguous()?; //must be contiguous tensor; + Ok(Self { + sin: freqs.sin()?, + cos: freqs.cos()?, + cos_sin, + }) + } + + fn apply_rotary_emb_qkv( + &self, + q: &Tensor, + k: &Tensor, + seqlen_offset: usize, + ) -> Result<(Tensor, Tensor)> { + let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?; + let cos = self.cos.narrow(0, seqlen_offset, seq_len)?; + let sin = self.sin.narrow(0, seqlen_offset, seq_len)?; + let q_embed = candle_nn::rotary_emb::rope(&q.contiguous()?, &cos, &sin)?; + let k_embed = candle_nn::rotary_emb::rope(&k.contiguous()?, &cos, &sin)?; + Ok((q_embed, k_embed)) + } +} + +struct Attention { + qkv_proj: Linear, + o_proj: Linear, + num_heads: usize, + num_kv_heads: usize, + num_kv_groups: usize, + hidden_size: usize, + head_dim: usize, + rotary_emb: Arc, + attn: PagedAttention, +} + +impl Attention { + fn new(rotary_emb: Arc, cfg: &Config, vb: VarBuilder) -> Result { + let num_heads = cfg.num_attention_heads; + let num_kv_heads = cfg.num_key_value_heads; + let head_dim = cfg.hidden_size / cfg.num_attention_heads; + let op_size = num_heads * head_dim + 2 * num_kv_heads * head_dim; + let qkv_proj = linear(cfg.hidden_size, op_size, vb.pp("qkv_proj"))?; + let o_proj = linear(num_heads * head_dim, cfg.hidden_size, vb.pp("o_proj"))?; + Ok(Self { + qkv_proj, + o_proj, + rotary_emb, + num_heads, + num_kv_heads, + num_kv_groups: num_heads / num_kv_heads, + head_dim, + hidden_size: cfg.hidden_size, + attn: PagedAttention::new( + cfg.num_attention_heads, + head_dim, + 1. 
/ ((head_dim as f32).sqrt()), + Some(cfg.num_key_value_heads), + None, + vb.device().clone(), + None, + )?, + }) + } + + fn forward( + &mut self, + xs: &Tensor, + attention_mask: Option<&Tensor>, + seqlen_offset: usize, + cache: Option<(&Tensor, &Tensor)>, + input_metadata: &mut InputMetadata, + ) -> Result { + let (b_sz, seq_len, _) = xs.dims3()?; + + let qkv = self.qkv_proj.forward(xs)?; + let query_pos = self.num_heads * self.head_dim; + let query_states = qkv.narrow(D::Minus1, 0, query_pos)?.contiguous()?; + let key_states = qkv + .narrow(D::Minus1, query_pos, self.num_kv_heads * self.head_dim)? + .contiguous()?; + let value_states = qkv + .narrow( + D::Minus1, + query_pos + self.num_kv_heads * self.head_dim, + self.num_kv_heads * self.head_dim, + )? + .contiguous()?; + + let (q, k, v) = if seq_len == 1 { + //no need transpose for seq_len == 1, change reshape dim + let q = query_states.reshape((b_sz, self.num_heads, seq_len, self.head_dim))?; + let k = key_states.reshape((b_sz, self.num_kv_heads, seq_len, self.head_dim))?; + let v = value_states.reshape((b_sz, self.num_kv_heads, seq_len, self.head_dim))?; + (q, k, v) + } else { + let q = query_states + .reshape((b_sz, seq_len, self.num_heads, self.head_dim))? + .transpose(1, 2)?; + let k = key_states + .reshape((b_sz, seq_len, self.num_kv_heads, self.head_dim))? + .transpose(1, 2)?; + let v = value_states + .reshape((b_sz, seq_len, self.num_kv_heads, self.head_dim))? + .transpose(1, 2)?; + (q, k, v.contiguous()?) + }; + + //preserve the precision with F32 type + let (q, k) = self.rotary_emb.apply_rotary_emb_qkv( + &q.to_dtype(DType::F32)?, + &k.to_dtype(DType::F32)?, + seqlen_offset, + )?; + let q = q.to_dtype(v.dtype())?; + let k = k.to_dtype(v.dtype())?; + + let k = candle_transformers::utils::repeat_kv(k, self.num_kv_groups)?.contiguous()?; + let v = candle_transformers::utils::repeat_kv(v, self.num_kv_groups)?.contiguous()?; + let y = self.attn.forward( + &q, + &k, + &v, + attention_mask, + cache.map(|(k_, _)| k_.clone()), + cache.map(|(_, v_)| v_.clone()), + input_metadata, + )?; + + let y = if attention_mask.is_some() { + y.transpose(1, 2)? + .reshape(&[b_sz, seq_len, self.hidden_size])? + } else { + y.reshape(&[b_sz, seq_len, self.hidden_size])? 
+ }; + let y = self.o_proj.forward(&y)?; + Ok(y) + } +} + +#[derive(Debug, Clone)] +struct Mlp { + gate_up_proj: Linear, + down_proj: Linear, + act_fn: candle_nn::Activation, + i_size: usize, +} + +impl Mlp { + fn new(cfg: &Config, vb: VarBuilder) -> Result { + let hidden_size = cfg.hidden_size; + let i_size = cfg.intermediate_size; + let gate_up_proj = linear(hidden_size, 2 * i_size, vb.pp("gate_up_proj"))?; + let down_proj = linear(i_size, hidden_size, vb.pp("down_proj"))?; + Ok(Self { + gate_up_proj, + down_proj, + act_fn: cfg.hidden_act.unwrap(), + i_size, + }) + } +} + +impl Module for Mlp { + fn forward(&self, xs: &Tensor) -> Result { + let up_states = xs.apply(&self.gate_up_proj)?; + let gate = up_states.narrow(D::Minus1, 0, self.i_size)?; + let up_states = up_states.narrow(D::Minus1, self.i_size, self.i_size)?; + let up_states = (up_states * gate.apply(&self.act_fn))?; + up_states.apply(&self.down_proj) + } +} + +struct DecoderLayer { + self_attn: Attention, + mlp: Mlp, + input_layernorm: RmsNorm, + post_attention_layernorm: RmsNorm, +} + +impl DecoderLayer { + fn new(rotary_emb: Arc, cfg: &Config, vb: VarBuilder) -> Result { + let self_attn = Attention::new(rotary_emb, cfg, vb.pp("self_attn"))?; + let mlp = Mlp::new(cfg, vb.pp("mlp"))?; + let input_layernorm = + RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("input_layernorm"))?; + let post_attention_layernorm = RmsNorm::new( + cfg.hidden_size, + cfg.rms_norm_eps, + vb.pp("post_attention_layernorm"), + )?; + Ok(Self { + self_attn, + mlp, + input_layernorm, + post_attention_layernorm, + }) + } + + fn forward( + &mut self, + xs: &Tensor, + attention_mask: Option<&Tensor>, + seqlen_offset: usize, + cache: Option<(&Tensor, &Tensor)>, + input_metadata: &mut InputMetadata, + ) -> Result { + let residual = xs; + let xs = self.input_layernorm.forward(xs)?; + let xs = + self.self_attn + .forward(&xs, attention_mask, seqlen_offset, cache, input_metadata)?; + let xs = (xs + residual)?; + let residual = &xs; + let xs = xs.apply(&self.post_attention_layernorm)?.apply(&self.mlp)?; + residual + xs + } +} + +pub struct Phi { + embed_tokens: candle_nn::Embedding, + layers: Vec, + norm: RmsNorm, + lm_head: Linear, + device: Device, + dtype: DType, + cfg: Config, +} + +impl Phi { + pub fn new(vb: VarBuilder, cfg: &Config, dtype: DType, device: &Device) -> Result { + let vb_m = vb.pp("model"); + let embed_tokens = + candle_nn::embedding(cfg.vocab_size, cfg.hidden_size, vb_m.pp("embed_tokens"))?; + let rotary_emb = Arc::new(RotaryEmbedding::new(dtype, cfg, vb_m.device())?); + let mut layers = Vec::with_capacity(cfg.num_hidden_layers); + let vb_l = vb_m.pp("layers"); + for layer_idx in 0..cfg.num_hidden_layers { + let layer = DecoderLayer::new(rotary_emb.clone(), cfg, vb_l.pp(layer_idx))?; + layers.push(layer) + } + let norm = RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb_m.pp("norm"))?; + let lm_head = linear(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?; + Ok(Self { + embed_tokens, + layers, + norm, + lm_head, + device: device.clone(), + dtype: dtype, + cfg: cfg.clone(), + }) + } + + fn prepare_decoder_attention_mask( + &self, + b_size: usize, + tgt_len: usize, + seqlen_offset: usize, + ) -> Result { + let mask: Vec<_> = (0..tgt_len) + .flat_map(|i| (0..tgt_len).map(move |j| if i < j { f32::NEG_INFINITY } else { 0. 
})) + .collect(); + let mask = Tensor::from_slice(&mask, (tgt_len, tgt_len), &self.device)?; + let mask = if seqlen_offset > 0 { + let mask0 = Tensor::zeros((tgt_len, seqlen_offset), DType::F32, &self.device)?; + Tensor::cat(&[&mask0, &mask], D::Minus1)? + } else { + mask + }; + mask.expand((b_size, 1, tgt_len, tgt_len + seqlen_offset))? + .to_dtype(self.dtype) + } + + pub fn forward( + &mut self, + input_ids: &Tensor, + seqlen_offset: usize, + kv_caches: Option<&Vec<(Tensor, Tensor)>>, + input_metadata: &mut InputMetadata, + ) -> Result { + let (b_size, seq_len) = input_ids.dims2()?; + let attention_mask = if seq_len <= 1 { + None + } else { + let mask = self.prepare_decoder_attention_mask(b_size, seq_len, seqlen_offset)?; + Some(mask) + }; + let mut xs = self.embed_tokens.forward(input_ids)?; + + if let Some(kv_caches) = kv_caches { + for ((k_cache, v_cache), layer) in zip(kv_caches.iter(), self.layers.iter_mut()) { + xs = layer.forward( + &xs, + attention_mask.as_ref(), + seqlen_offset, + Some((k_cache, v_cache)), + input_metadata, + )? + } + } else { + for layer in self.layers.iter_mut() { + xs = layer.forward( + &xs, + attention_mask.as_ref(), + seqlen_offset, + None, + input_metadata, + )? + } + } + xs.narrow(1, seq_len - 1, 1)? + .apply(&self.norm)? + .apply(&self.lm_head)? + .i((.., 0, ..))? + .squeeze(0)? + .to_dtype(DType::F32) + } + + pub fn get_config(&self) -> &Config { + &self.cfg + } +} diff --git a/src/openai/openai_server.rs b/src/openai/openai_server.rs index 786f2a9..b4c779a 100644 --- a/src/openai/openai_server.rs +++ b/src/openai/openai_server.rs @@ -125,7 +125,7 @@ async fn chat_completions( if prompt.is_err() { return Either::Left(Err(prompt.err().unwrap())); } - let mut prompt = prompt.unwrap(); + let prompt = prompt.unwrap(); let token_ids = check_length(&request, prompt.clone(), &data).await; if token_ids.is_err() { @@ -134,8 +134,13 @@ async fn chat_completions( let mut token_ids: Encoding = token_ids.unwrap(); if token_ids.len() % 2 == 0 { //padding to avoid block allocation issue - prompt += "\n"; - token_ids = check_length(&request, prompt.clone(), &data).await.unwrap(); + token_ids.pad( + token_ids.len() + 1, + 0, + 0, + "\n", + tokenizers::PaddingDirection::Right, + ); } println!("\n\n\nPrompt {:?}", prompt); diff --git a/src/openai/pipelines/llm_engine.rs b/src/openai/pipelines/llm_engine.rs index 03fa26d..25abf03 100644 --- a/src/openai/pipelines/llm_engine.rs +++ b/src/openai/pipelines/llm_engine.rs @@ -64,7 +64,7 @@ impl<'a> LLMEngine<'a> { pipeline.get_dtype(), &pipeline.device(), )?; - let sliding_window = pipeline.get_model_config().get_sliding_window(); + let sliding_window = pipeline.get_model_config().sliding_window; Ok(Self { pipeline, scheduler: Scheduler::new(scheduler_config, &cache_config), @@ -344,7 +344,11 @@ impl<'a> LLMEngine<'a> { .collect::>(); let start_idx = if let Some(sliding_window) = self.sliding_window { - 0.min(prompt_len - sliding_window) + if prompt_len > sliding_window { + 0.min(prompt_len - sliding_window) + } else { + 0 + } } else { 0 }; @@ -458,12 +462,12 @@ impl<'a> LLMEngine<'a> { if let Some(sliding_window) = self.sliding_window { let sliding_window_blocks = sliding_window / self.cache_config.block_size; - block_tables.push( - table - .get(table.len() - sliding_window_blocks..) 
- .unwrap() - .to_vec(), - ); + let slide_idx = if table.len() > sliding_window_blocks { + table.len() - sliding_window_blocks + } else { + 0 + }; + block_tables.push(table.get(slide_idx..).unwrap().to_vec()); } else { block_tables.push(table); } diff --git a/src/openai/pipelines/mod.rs b/src/openai/pipelines/mod.rs index 0772ba8..c288532 100644 --- a/src/openai/pipelines/mod.rs +++ b/src/openai/pipelines/mod.rs @@ -9,14 +9,14 @@ use crate::{ }; use super::{ - conversation::Conversation, models::ConfigLike, responses::APIError, + conversation::Conversation, models::Config, responses::APIError, sampling_params::SamplingParams, PipelineConfig, TokenizerWrapper, }; use candle_examples::token_output_stream::TokenOutputStream; -pub mod llama; /// The LLMEngine is effectively a wrapper around a ModulePipeline. It contains a Scheduler and a CacheEngine /// which are used to scheduler and manage the cache during generation requests, respectively. pub mod llm_engine; +pub mod pipeline; type TokenOrFinishReason = Either; @@ -42,7 +42,7 @@ pub trait ModulePipeline<'s>: Send + Sync { fn get_conversation(&mut self, with_history: bool) -> &mut dyn Conversation; - fn get_model_config(&self) -> Box; + fn get_model_config(&self) -> Config; fn get_dtype(&self) -> DType; diff --git a/src/openai/pipelines/llama.rs b/src/openai/pipelines/pipeline.rs similarity index 72% rename from src/openai/pipelines/llama.rs rename to src/openai/pipelines/pipeline.rs index 8daee5f..d372732 100644 --- a/src/openai/pipelines/llama.rs +++ b/src/openai/pipelines/pipeline.rs @@ -9,8 +9,9 @@ use crate::{ Conversation, }, models::{ - llama::{Cache, Config, Llama, LlamaConfig}, - ConfigLike, + llama::{Cache, Llama, LlamaConfig}, + phi3::{Phi, PhiConfig}, + Config, }, requests::StopTokens, responses::APIError, @@ -34,20 +35,24 @@ const EOS_TOKEN: &str = ""; const SAMPLING_SEED: u64 = 299792458; #[derive(Debug, Clone)] -pub struct LlamaSpecificConfig { +pub struct SpecificConfig { repeat_last_n: usize, } -impl LlamaSpecificConfig { +impl SpecificConfig { pub fn new(repeat_last_n: usize) -> Self { Self { repeat_last_n } } } +enum LLMModel { + LLAMA(Llama), + Phi3(Phi), +} /// top-p, multinomial, and argmax sampling are implemented. Beam search is not implemented. -pub struct LlamaPipeline { - llama: Llama, - args: LlamaSpecificConfig, +pub struct DefaultPipeline { + model: LLMModel, + args: SpecificConfig, tokenizer: TokenOutputStream, logits_processor: LogitsProcessor, conversation: DefaultConversation, @@ -58,18 +63,18 @@ pub struct LlamaPipeline { config: Config, } -pub struct LlamaLoader { - config: LlamaSpecificConfig, +pub struct DefaultLoader { + config: SpecificConfig, name: String, } -pub struct LlamaModelPaths
<P> {
+pub struct DefaultModelPaths<P> {
    pub tokenizer_filename: P,
    pub config_filename: P,
    pub filenames: Vec<P>
, } -impl ModelPaths for LlamaModelPaths { +impl ModelPaths for DefaultModelPaths { fn get_config_filename(&self) -> &PathBuf { &self.config_filename } @@ -81,13 +86,13 @@ impl ModelPaths for LlamaModelPaths { } } -impl LlamaLoader { - pub fn new(config: LlamaSpecificConfig, name: String) -> Self { +impl DefaultLoader { + pub fn new(config: SpecificConfig, name: String) -> Self { Self { config, name } } } -impl<'a> ModelLoader<'a> for LlamaLoader { +impl<'a> ModelLoader<'a> for DefaultLoader { fn download_model( &self, model_id: String, @@ -117,7 +122,7 @@ impl<'a> ModelLoader<'a> for LlamaLoader { filenames.push(filename); } - Ok(Box::new(LlamaModelPaths { + Ok(Box::new(DefaultModelPaths { tokenizer_filename, config_filename, filenames, @@ -132,10 +137,21 @@ impl<'a> ModelLoader<'a> for LlamaLoader { ) -> Result<(Box>, PipelineConfig), APIError> { let args = self.config.clone(); - let config: LlamaConfig = try_api!(serde_json::from_slice(&try_api!(std::fs::read( - paths.get_config_filename() - )),)); - let config = config.into_config(false); + let config = match self.name.as_str() { + "llama7b" | "llama13b" | "llama70b" => { + let config: LlamaConfig = try_api!(serde_json::from_slice(&try_api!( + std::fs::read(paths.get_config_filename()) + ),)); + config.into_config(false) + } + "phi3" => { + let config: PhiConfig = try_api!(serde_json::from_slice(&try_api!(std::fs::read( + paths.get_config_filename() + )),)); + config.into_config(false) + } + _ => panic!(""), + }; println!("Loading {} model.", self.name); @@ -146,7 +162,17 @@ impl<'a> ModelLoader<'a> for LlamaLoader { _ => panic!("Load model weights failed!"), }; - let llama = try_api!(Llama::load(vb, &config, dtype, &device)); + let (model, sep_style) = match self.name.as_str() { + "llama7b" | "llama13b" | "llama70b" => ( + LLMModel::LLAMA(try_api!(Llama::load(vb, &config, dtype, &device))), + SeparatorStyle::Llama2, + ), + "phi3" => ( + LLMModel::Phi3(try_api!(Phi::new(vb, &config, dtype, &device))), + SeparatorStyle::Phi, + ), + _ => panic!(""), + }; let tokenizer_ = Tokenizer::from_file(paths.get_tokenizer_filename()) .map_err(|x| APIError::new(x.to_string()))?; @@ -163,17 +189,17 @@ impl<'a> ModelLoader<'a> for LlamaLoader { //reference: https://huggingface.co/blog/codellama#conversational-instructions, //reference: https://github.com/facebookresearch/llama/blob/1a240688810f8036049e8da36b073f63d2ac552c/llama/generation.py#L212 Ok(( - Box::new(LlamaPipeline { - llama, + Box::new(DefaultPipeline { + model, args, tokenizer, logits_processor: LogitsProcessor::new(SAMPLING_SEED, None, None), conversation: DefaultConversation::new( - "llama-2".to_string(), - "[INST] <>\n{}\n<>\n\n [/INST] ".to_string(), + self.name.to_string(), + "[INST] <>\n{}\n<>\n\n [/INST]".to_string(), Vec::default(), 0, - SeparatorStyle::Llama2, + sep_style, "".to_string(), Vec::default(), ("user".to_string(), "assistant".to_string()), @@ -193,7 +219,7 @@ impl<'a> ModelLoader<'a> for LlamaLoader { } } -impl<'s> ModulePipeline<'s> for LlamaPipeline { +impl<'s> ModulePipeline<'s> for DefaultPipeline { fn forward( &mut self, input_tokens: Tensor, @@ -205,18 +231,30 @@ impl<'s> ModulePipeline<'s> for LlamaPipeline { if length > 1 { self.cur_idx = 0; } + let ret = match &mut self.model { + LLMModel::LLAMA(llama) => llama + .forward( + &input_tokens + .reshape((1, input_tokens.shape().dims()[0])) + .unwrap(), + self.cur_idx, + kv_cache, + &mut input_metadata, + ) + .map_err(APIError::from), + LLMModel::Phi3(phi) => phi + .forward( + &input_tokens + .reshape((1, 
input_tokens.shape().dims()[0])) + .unwrap(), + self.cur_idx, + kv_cache, + &mut input_metadata, + ) + .map_err(APIError::from), + _ => panic!("Not supported model!"), + }; - let ret = self - .llama - .forward( - &input_tokens - .reshape((1, input_tokens.shape().dims()[0])) - .unwrap(), - self.cur_idx, - kv_cache, - &mut input_metadata, - ) - .map_err(APIError::from); self.cur_idx += length; return ret; } @@ -307,8 +345,12 @@ impl<'s> ModulePipeline<'s> for LlamaPipeline { &mut self.conversation } - fn get_model_config(&self) -> Box { - Box::new(self.llama.get_config().clone()) + fn get_model_config(&self) -> Config { + match &self.model { + LLMModel::LLAMA(llama) => llama.get_config().clone(), + LLMModel::Phi3(phi) => phi.get_config().clone(), + _ => panic!("Not supported model!"), + } } fn get_dtype(&self) -> DType { @@ -320,5 +362,5 @@ impl<'s> ModulePipeline<'s> for LlamaPipeline { } } -unsafe impl Send for LlamaPipeline {} -unsafe impl Sync for LlamaPipeline {} +unsafe impl Send for DefaultPipeline {} +unsafe impl Sync for DefaultPipeline {} diff --git a/src/paged_attention/mod.rs b/src/paged_attention/mod.rs index 091749b..ed4b10a 100644 --- a/src/paged_attention/mod.rs +++ b/src/paged_attention/mod.rs @@ -86,7 +86,7 @@ impl PagedAttention { Some(mask) => { let att = (query.matmul(&key.t()?)? * self.scale as f64)?; let att = att.broadcast_add(mask)?; - let att = candle_nn::ops::softmax(&att, D::Minus1)?; + let att = candle_nn::ops::softmax_last_dim(&att)?; Some(att.matmul(&value)?) } }; diff --git a/src/scheduler/cache_engine.rs b/src/scheduler/cache_engine.rs index 7c9c86d..09bf170 100644 --- a/src/scheduler/cache_engine.rs +++ b/src/scheduler/cache_engine.rs @@ -7,7 +7,7 @@ use candle_core::{DType, Device, Tensor}; use crate::{ backend::{copy_blocks, swap_blocks}, - openai::{models::ConfigLike, responses::APIError}, + openai::{models::Config, responses::APIError}, try_api, }; @@ -44,20 +44,20 @@ pub struct CacheEngine { impl CacheEngine { pub fn new( - model_config: Box, + model_config: Config, cache_config: CacheConfig, dtype: DType, device: &Device, ) -> Result { Ok(Self { gpu_cache: Arc::new(Mutex::new(Self::allocate_gpu_cache( - &*model_config, + &model_config, &cache_config, dtype, device, )?)), - cpu_cache: Self::allocate_cpu_cache(&*model_config, &cache_config, dtype, device)?, - num_layers: model_config.get_num_hidden_layers(), + cpu_cache: Self::allocate_cpu_cache(&model_config, &cache_config, dtype, device)?, + num_layers: model_config.num_hidden_layers, }) } @@ -70,7 +70,7 @@ impl CacheEngine { } fn allocate_gpu_cache( - model_config: &dyn ConfigLike, + model_config: &Config, cache_config: &CacheConfig, dtype: DType, device: &Device, @@ -82,7 +82,7 @@ impl CacheEngine { let value_block_shape = Self::calculate_value_block_shape(model_config, cache_config.block_size); let mut gpu_cache = Vec::new(); - for _ in 0..model_config.get_num_hidden_layers() { + for _ in 0..model_config.num_hidden_layers { let key_blocks = try_api!(Tensor::zeros( ( cache_config.num_gpu_blocks.unwrap(), @@ -110,7 +110,7 @@ impl CacheEngine { } fn allocate_cpu_cache( - model_config: &dyn ConfigLike, + model_config: &Config, cache_config: &CacheConfig, dtype: DType, device: &Device, @@ -122,7 +122,7 @@ impl CacheEngine { let value_block_shape = Self::calculate_value_block_shape(model_config, cache_config.block_size); let mut cpu_cache = Vec::new(); - for _ in 0..model_config.get_num_hidden_layers() { + for _ in 0..model_config.num_hidden_layers { let key_blocks = try_api!(Tensor::zeros( ( 
cache_config.num_cpu_blocks.unwrap(), @@ -152,27 +152,27 @@ impl CacheEngine { impl CacheEngine { fn calculate_key_block_shape( - model_config: &dyn ConfigLike, + model_config: &Config, dtype: DType, block_size: usize, ) -> (usize, usize, usize, usize) { let element_size = dtype.size_in_bytes(); let x = 16 / element_size; ( - model_config.get_num_kv_heads(), - model_config.get_head_size() / x, + model_config.num_key_value_heads, + model_config.hidden_size / model_config.num_attention_heads / x, block_size, x, ) } fn calculate_value_block_shape( - model_config: &dyn ConfigLike, + model_config: &Config, block_size: usize, ) -> (usize, usize, usize) { ( - model_config.get_num_kv_heads(), - model_config.get_head_size(), + model_config.num_key_value_heads, + model_config.hidden_size / model_config.num_attention_heads, block_size, ) } diff --git a/tests/tests.rs b/tests/tests.rs index c340f84..5a6f542 100644 --- a/tests/tests.rs +++ b/tests/tests.rs @@ -1,8 +1,3 @@ -use std::{ - collections::HashMap, - sync::{Arc, Mutex}, -}; - use actix_web::{http::header::ContentType, test, web::Data, App}; use candle_core::{DType, Device}; use candle_vllm::{ @@ -14,6 +9,8 @@ use candle_vllm::{ scheduler::{cache_engine::CacheConfig, SchedulerConfig}, ModelSelected, }; +use futures::lock::Mutex; +use std::{collections::HashMap, sync::Arc}; #[actix_web::test] async fn test_llama() -> Result<(), APIError> { @@ -40,6 +37,7 @@ async fn test_llama() -> Result<(), APIError> { pipeline_config: model.1, model: Arc::new(Mutex::new(llm_engine)), device: Device::Cpu, + record_conversation: false, }; let app = test::init_service( @@ -84,6 +82,7 @@ async fn test_llama() -> Result<(), APIError> { skip_special_tokens: None, ignore_eos: None, stop_token_ids: None, + logprobs: None, }) .to_request();
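
Note on the new `SeparatorStyle::Phi` branch in default_conversation.rs: the template it assembles can be reproduced in isolation. The sketch below is a minimal, self-contained illustration; the `render_phi_prompt` helper and the plain role strings are hypothetical stand-ins for the conversation machinery in the patch, not code taken from it.

// Minimal sketch of the Phi-3 chat template produced by SeparatorStyle::Phi:
// user turns render as "<|user|> ...<|end|>", assistant turns as
// "<|assistant|>...<|end|>", and a non-empty system prompt is emitted first.
// The helper and role strings are illustrative, not part of this patch.
fn render_phi_prompt(system_prompt: &str, messages: &[(&str, &str)]) -> String {
    let mut accum = String::new();
    if !system_prompt.is_empty() {
        accum.push_str(system_prompt);
    }
    for (role, text) in messages {
        match *role {
            "user" => accum += &format!("<|user|> {text}<|end|>"),
            "assistant" => accum += &format!("<|assistant|>{text}<|end|>"),
            _ => {} // other roles are ignored in this simplified sketch
        }
    }
    accum
}

fn main() {
    let prompt = render_phi_prompt("", &[("user", "Hello!"), ("assistant", "Hi there.")]);
    assert_eq!(prompt, "<|user|> Hello!<|end|><|assistant|>Hi there.<|end|>");
    println!("{prompt}");
}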
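Note on the KV-cache sizing in src/main.rs: the patch divides the configured memory budget by the per-block cache footprint to get the number of PagedAttention blocks. The standalone sketch below restates that arithmetic; it assumes `SIZE_IN_MB` is 1 MiB in bytes, and the function name and example numbers (block size, LLaMa2-7B-like head/layer counts) are illustrative rather than values from the patch.

const SIZE_IN_MB: usize = 1024 * 1024; // assumed value of the constant used in main.rs

// Illustrative restatement of the block-budget formula in src/main.rs:
// one cache block holds `block_size` tokens, and each token stores a key and
// a value entry of `head_size` elements per KV head in every layer (the
// trailing factor of 2 accounts for the key/value pair).
fn num_cache_blocks(
    mem_mb: usize,               // configured KV-cache budget in MB
    dtype_size: usize,           // bytes per element, e.g. 2 for BF16
    block_size: usize,           // tokens per block
    num_key_value_heads: usize,
    head_size: usize,            // hidden_size / num_attention_heads
    num_hidden_layers: usize,
) -> usize {
    let bytes_per_block =
        dtype_size * block_size * num_key_value_heads * head_size * num_hidden_layers * 2;
    mem_mb * SIZE_IN_MB / bytes_per_block
}

fn main() {
    // Example numbers only: 4096 MiB budget, BF16, 16-token blocks,
    // 32 KV heads, head dim 128, 32 layers (LLaMa2-7B-like shape).
    let blocks = num_cache_blocks(4096, 2, 16, 32, 128, 32);
    println!("{blocks} cache blocks"); // 4096 MiB / 8 MiB per block = 512
}

The same formula is applied separately to the GPU (`kvcache_mem_gpu`) and CPU (`kvcache_mem_cpu`) budgets in the patch.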