Unified pipeline for models & support phi3 model (#45)

* Optional logprobs & fix llama eos/stop token

* Cargo fmt

* Mention other options for chat completion request

* Configurable kvcache & fix repeat chat history

* Improve readability

* Instructions for ChatUI & add demo chat video

* Optimization for decoding stage & try to fix blocktable issue

* Support stream response for chat completion

* Update ReadMe & demo video

* Reduce demo video size

* Fix stream generation hang in release mode

* Reduce the buffer size & update ReadMe

* Fix LLaMa2 prompt instruction (for long conversation)

* Cargo fmt

* Padding to avoid block allocation issue & revision for prompt instruction

* Unified pipeline for models & support phi3 model

* Fix padding strategy

* Cargo fmt

* Update ReadMe for supported models
guoqingbao authored Jul 3, 2024
1 parent ae35a3a commit 743a8b2
Showing 14 changed files with 652 additions and 203 deletions.
24 changes: 19 additions & 5 deletions README.md
@@ -13,11 +13,25 @@ Efficient, easy-to-use platform for inference and serving local LLMs including a
- Efficient management of key-value cache with PagedAttention.
- Continuous batching.

### Pipelines
- Llama
- 7b
- 13b
- 70b
## Develop Status

Currently, candle-vllm supports chat serving for the following models.

| Model ID | Model Type | Supported | Speed (A100, BF16) |
|--|--|--|--|
| #1 | **LLAMA/LLAMA2/LLaMa3** |✅|71 tks/s (7B)|
| #2 | Mistral |TBD|TBD|
| #3 | Phi (v1, v1.5, v2) |TBD|TBD|
| #4 | **Phi-3 (3.8B, 7B)** |✅|99 tks/s (3.8B)|
| #5 | Yi |TBD|TBD|
| #6 | StableLM |TBD|TBD|
| #7 | BigCode/StarCode |TBD|TBD|
| #8 | ChatGLM |TBD|TBD|
| #9 | QWen |TBD|TBD|
| #10 | Google Gemma |TBD|TBD|
| #11 | Blip-large (Multimodal) |TBD|TBD|
| #12 | Moondream-2 (Multimodal LLM) |TBD|TBD|


## Demo Chat with candle-vllm (71 tokens/s, LLaMa2 7B, bf16, on A100)
<img src="./res/candle-vllm-demo.gif" width="90%" height="90%" >
29 changes: 22 additions & 7 deletions src/lib.rs
@@ -3,7 +3,7 @@ use candle::Result;
use candle_core as candle;
use clap::Subcommand;
use openai::pipelines::{
llama::{LlamaLoader, LlamaSpecificConfig},
pipeline::{DefaultLoader, SpecificConfig},
ModelLoader,
};

@@ -29,6 +29,13 @@ pub enum ModelSelected {
#[arg(long)]
repeat_last_n: usize,
},

/// Select the phi3 3.8b model.
Phi3 {
/// Control the application of repeat penalty for the last n tokens
#[arg(long)]
repeat_last_n: usize,
},
}

impl ToString for ModelSelected {
@@ -37,33 +44,41 @@ impl ToString for ModelSelected {
ModelSelected::Llama7b { repeat_last_n: _ } => "llama7b".to_string(),
ModelSelected::Llama13b { repeat_last_n: _ } => "llama13b".to_string(),
ModelSelected::Llama70b { repeat_last_n: _ } => "llama70b".to_string(),
ModelSelected::Phi3 { repeat_last_n: _ } => "phi3".to_string(),
}
}
}

pub fn get_model_loader<'a>(selected_model: ModelSelected) -> (Box<dyn ModelLoader<'a>>, String) {
match selected_model {
ModelSelected::Llama7b { repeat_last_n } => (
Box::new(LlamaLoader::new(
LlamaSpecificConfig::new(repeat_last_n),
Box::new(DefaultLoader::new(
SpecificConfig::new(repeat_last_n),
"llama7b".to_string(),
)),
"meta-llama/Llama-2-7b-chat-hf".to_string(),
),
ModelSelected::Llama13b { repeat_last_n } => (
Box::new(LlamaLoader::new(
LlamaSpecificConfig::new(repeat_last_n),
Box::new(DefaultLoader::new(
SpecificConfig::new(repeat_last_n),
"llama13b".to_string(),
)),
"meta-llama/Llama-2-13b-chat-hf".to_string(),
),
ModelSelected::Llama70b { repeat_last_n } => (
Box::new(LlamaLoader::new(
LlamaSpecificConfig::new(repeat_last_n),
Box::new(DefaultLoader::new(
SpecificConfig::new(repeat_last_n),
"llama70b".to_string(),
)),
"meta-llama/Llama-2-70b-chat-hf".to_string(),
),
ModelSelected::Phi3 { repeat_last_n } => (
Box::new(DefaultLoader::new(
SpecificConfig::new(repeat_last_n),
"phi3".to_string(),
)),
"microsoft/Phi-3-mini-4k-instruct".to_string(),
),
}
}
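
For orientation, a minimal sketch of how the new unified loader is resolved from a CLI model selection. This assumes the `candle_vllm` crate as defined in this commit; the `repeat_last_n` value is illustrative:

```rust
use candle_vllm::{get_model_loader, ModelSelected};

fn main() {
    // With the unified pipeline, every model goes through the same DefaultLoader;
    // only the model-name string and the default Hugging Face model ID differ.
    let (loader, model_id) = get_model_loader(ModelSelected::Phi3 { repeat_last_n: 64 });
    println!("resolved default model id: {model_id}");
    let _ = loader;
}
```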

12 changes: 6 additions & 6 deletions src/main.rs
@@ -4,8 +4,8 @@ use actix_web::{App, HttpServer};
use candle_core::{DType, Device};
use candle_examples;
use candle_vllm::openai::openai_server::chat_completions;
use candle_vllm::openai::pipelines::llama::LlamaModelPaths;
use candle_vllm::openai::pipelines::llm_engine::LLMEngine;
use candle_vllm::openai::pipelines::pipeline::DefaultModelPaths;
use candle_vllm::openai::responses::APIError;
use candle_vllm::openai::OpenAIServerData;
use candle_vllm::scheduler::cache_engine::CacheConfig;
@@ -77,7 +77,7 @@ async fn main() -> Result<(), APIError> {
let (loader, model_id) = get_model_loader(args.command);

let paths = match &args.weight_path {
Some(path) => Box::new(LlamaModelPaths {
Some(path) => Box::new(DefaultModelPaths {
tokenizer_filename: (path.to_owned() + "tokenizer.json").into(),
config_filename: (path.to_owned() + "config.json").into(),
filenames: hub_load_local_safetensors(path, "model.safetensors.index.json").unwrap(),
@@ -100,16 +100,16 @@ async fn main() -> Result<(), APIError> {
let num_gpu_blocks = args.kvcache_mem_gpu * SIZE_IN_MB
/ dsize
/ args.block_size
/ config.get_num_kv_heads()
/ config.num_key_value_heads
/ config.get_head_size()
/ config.get_num_hidden_layers()
/ config.num_hidden_layers
/ 2;
let num_cpu_blocks = args.kvcache_mem_cpu * SIZE_IN_MB
/ dsize
/ args.block_size
/ config.get_num_kv_heads()
/ config.num_key_value_heads
/ config.get_head_size()
/ config.get_num_hidden_layers()
/ config.num_hidden_layers
/ 2;
let cache_config = CacheConfig {
block_size: args.block_size,
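
For intuition, the KV-cache block-count arithmetic above divides the cache budget by the per-block byte footprint (dtype size × block size × KV heads × head size × layers × 2 for the K and V planes). A standalone sketch with illustrative numbers follows; these are not the project's defaults, and `SIZE_IN_MB` is assumed to be 1024 * 1024:

```rust
fn main() {
    // Illustrative only: 4 GB of GPU KV-cache, bf16 (2 bytes), block_size 16,
    // 32 KV heads, head size 128, 32 layers, and a factor of 2 for K and V.
    const SIZE_IN_MB: usize = 1024 * 1024;
    let kvcache_mem_gpu = 4096; // MB, hypothetical budget
    let (dsize, block_size) = (2, 16);
    let (num_key_value_heads, head_size, num_hidden_layers) = (32, 128, 32);
    let num_gpu_blocks = kvcache_mem_gpu * SIZE_IN_MB
        / dsize
        / block_size
        / num_key_value_heads
        / head_size
        / num_hidden_layers
        / 2;
    println!("{num_gpu_blocks} GPU KV-cache blocks"); // 512 with these numbers
}
```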
24 changes: 24 additions & 0 deletions src/openai/conversation/default_conversation.rs
@@ -17,6 +17,7 @@ pub enum SeparatorStyle {
NoColonTwo,
AddNewLineSingle,
Llama2,
Phi,
ChatGLM,
ChatML,
ChatIntern,
@@ -242,6 +243,29 @@ impl Conversation for DefaultConversation {
accum
}

SeparatorStyle::Phi => {
let mut accum = "".to_string();
for (i, message) in self.messages.iter().enumerate() {
let Message((_role, message)) = message;
if _role.clone() == self.roles.0 {
//user message
if let Some(message) = message {
accum += &format!("<|user|> {message}<|end|>");
} else {
                            accum += &format!("<|user|> <|end|>");
}
} else if _role.clone() == self.roles.1 {
//assistant message
if let Some(message) = message {
accum += &format!("<|assistant|>{message}<|end|>");
}
} else if i == 0 && !system_prompt.is_empty() {
accum += &system_prompt;
}
}
accum
}

SeparatorStyle::ChatGLM => {
let round_add_n = if self.name == "chatglm2" { 1 } else { 0 };

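
To illustrate what the new `SeparatorStyle::Phi` branch produces, here is a minimal sketch that mimics its accumulation loop on a made-up two-turn history (system prompt omitted; the messages are hypothetical):

```rust
fn main() {
    // Hypothetical two-message history rendered with the Phi separator style.
    let turns = [
        ("user", "What is PagedAttention?"),
        ("assistant", "A paged KV-cache layout for attention."),
    ];
    let mut accum = String::new();
    for (role, msg) in turns {
        if role == "user" {
            accum += &format!("<|user|> {msg}<|end|>");
        } else {
            accum += &format!("<|assistant|>{msg}<|end|>");
        }
    }
    println!("{accum}");
    // <|user|> What is PagedAttention?<|end|><|assistant|>A paged KV-cache layout for attention.<|end|>
}
```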
106 changes: 6 additions & 100 deletions src/openai/models/llama.rs
@@ -1,4 +1,4 @@
use super::ConfigLike;
use super::Config;
use crate::paged_attention::input_metadata::InputMetadata;
use crate::paged_attention::PagedAttention;
use candle::{DType, Device, IndexOp, Result, Tensor, D};
@@ -23,37 +23,10 @@ pub struct LlamaConfig {
pub eos_token_id: Option<u32>,
}

impl LlamaConfig {
pub fn num_key_value_heads(&self) -> usize {
self.num_key_value_heads.unwrap_or(self.num_attention_heads)
}
}

fn default_rope() -> f32 {
10_000.0
}

impl ConfigLike for LlamaConfig {
fn get_num_kv_heads(&self) -> usize {
self.num_key_value_heads.unwrap_or(self.num_attention_heads)
}
fn get_hidden_size(&self) -> usize {
self.hidden_size
}
fn get_num_hidden_layers(&self) -> usize {
self.num_hidden_layers
}
fn get_num_attention_heads(&self) -> usize {
self.num_attention_heads
}
fn get_vocab_size(&self) -> usize {
self.vocab_size
}
fn get_sliding_window(&self) -> Option<usize> {
None
}
}

impl LlamaConfig {
pub fn into_config(self, use_flash_attn: bool) -> Config {
Config {
@@ -68,80 +41,13 @@ impl LlamaConfig {
use_flash_attn,
bos_token_id: self.bos_token_id,
eos_token_id: self.eos_token_id,
max_seq_len: MAX_SEQ_LEN,
sliding_window: None,
hidden_act: None,
}
}
}

#[derive(Debug, Clone)]
pub struct Config {
pub hidden_size: usize,
pub intermediate_size: usize,
pub vocab_size: usize,
pub num_hidden_layers: usize,
pub num_attention_heads: usize,
pub num_key_value_heads: usize,
pub use_flash_attn: bool,
pub rms_norm_eps: f64,
pub rope_theta: f32,
pub bos_token_id: Option<u32>,
pub eos_token_id: Option<u32>,
}

impl Config {
pub fn config_7b_v1(use_flash_attn: bool) -> Self {
Self {
hidden_size: 4096,
intermediate_size: 11008,
vocab_size: 32000,
num_hidden_layers: 32,
num_attention_heads: 32,
num_key_value_heads: 32,
use_flash_attn,
rms_norm_eps: 1e-6,
rope_theta: 10_000.0,
bos_token_id: None,
eos_token_id: None,
}
}

pub fn config_7b_v2(use_flash_attn: bool) -> Self {
Self {
hidden_size: 4096,
intermediate_size: 11008,
vocab_size: 32000,
num_hidden_layers: 32,
num_attention_heads: 32,
num_key_value_heads: 32,
use_flash_attn,
rms_norm_eps: 1e-5,
rope_theta: 10_000.0,
bos_token_id: None,
eos_token_id: None,
}
}
}

impl ConfigLike for Config {
fn get_num_kv_heads(&self) -> usize {
self.num_key_value_heads
}
fn get_hidden_size(&self) -> usize {
self.hidden_size
}
fn get_num_hidden_layers(&self) -> usize {
self.num_hidden_layers
}
fn get_num_attention_heads(&self) -> usize {
self.num_attention_heads
}
fn get_vocab_size(&self) -> usize {
self.vocab_size
}
fn get_sliding_window(&self) -> Option<usize> {
None
}
}

#[derive(Debug, Clone)]
pub struct Cache {
masks: HashMap<usize, Tensor>,
@@ -159,9 +65,9 @@ impl Cache {
.map(|i| 1f32 / config.rope_theta.powf(i as f32 / n_elem as f32))
.collect();
let theta = Tensor::new(theta.as_slice(), device)?;
let idx_theta = Tensor::arange(0, MAX_SEQ_LEN as u32, device)?
let idx_theta = Tensor::arange(0, config.max_seq_len as u32, device)?
.to_dtype(DType::F32)?
.reshape((MAX_SEQ_LEN, 1))?
.reshape((config.max_seq_len, 1))?
.matmul(&theta.reshape((1, theta.elem_count()))?)?;
let cos = idx_theta.cos()?.to_dtype(dtype)?;
let sin = idx_theta.sin()?.to_dtype(dtype)?;
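
For context, the cos/sin tables built above follow the standard RoPE recipe, now sized by `config.max_seq_len` instead of the old `MAX_SEQ_LEN` constant. A tiny standalone sketch of the same frequencies using plain loops rather than candle tensors (head dimension and sequence length are illustrative):

```rust
fn main() {
    // Standard RoPE table: theta_i = 1 / rope_theta^(2i/d), then an outer
    // product with positions 0..max_seq_len, as in Cache::new above.
    let (n_elem, rope_theta, max_seq_len) = (8usize, 10_000.0f32, 4usize);
    let theta: Vec<f32> = (0..n_elem)
        .step_by(2)
        .map(|i| 1f32 / rope_theta.powf(i as f32 / n_elem as f32))
        .collect();
    for pos in 0..max_seq_len {
        let row: Vec<(f32, f32)> = theta
            .iter()
            .map(|t| ((pos as f32 * t).cos(), (pos as f32 * t).sin()))
            .collect();
        println!("pos {pos}: {row:?}"); // (cos, sin) pairs per frequency
    }
}
```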
31 changes: 22 additions & 9 deletions src/openai/models/mod.rs
@@ -1,13 +1,26 @@
pub mod llama;
pub mod phi3;

pub trait ConfigLike {
fn get_num_kv_heads(&self) -> usize;
fn get_hidden_size(&self) -> usize;
fn get_num_hidden_layers(&self) -> usize;
fn get_num_attention_heads(&self) -> usize;
fn get_vocab_size(&self) -> usize;
fn get_sliding_window(&self) -> Option<usize>;
fn get_head_size(&self) -> usize {
self.get_hidden_size() / self.get_num_attention_heads()
#[derive(Debug, Clone)]
pub struct Config {
pub hidden_size: usize,
pub intermediate_size: usize,
pub vocab_size: usize,
pub num_hidden_layers: usize,
pub num_attention_heads: usize,
pub num_key_value_heads: usize,
pub use_flash_attn: bool,
pub rms_norm_eps: f64,
pub rope_theta: f32,
pub bos_token_id: Option<u32>,
pub eos_token_id: Option<u32>,
pub max_seq_len: usize,
pub sliding_window: Option<usize>,
pub hidden_act: Option<candle_nn::Activation>,
}

impl Config {
pub fn get_head_size(&self) -> usize {
self.hidden_size / self.num_attention_heads
}
}
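
As a quick sanity check of `get_head_size`, the per-head dimension for a LLaMA-2-7B-sized config (hidden_size 4096, 32 attention heads; illustrative values) works out to 128:

```rust
fn main() {
    // Mirrors Config::get_head_size(): hidden_size / num_attention_heads.
    let (hidden_size, num_attention_heads) = (4096usize, 32usize);
    println!("head size = {}", hidden_size / num_attention_heads); // 128
}
```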
