Merge pull request #67 from EricLBuehler/develop
LLaMa3.1 chat completion
guoqingbao authored Jul 26, 2024
2 parents 8476f17 + 021a033 commit e922750
Showing 13 changed files with 141 additions and 27 deletions.
7 changes: 6 additions & 1 deletion README.md
@@ -19,7 +19,7 @@ Currently, candle-vllm supports chat serving for the following models.

| Model ID | Model Type | Supported | Speed (A100, BF16)
|--|--|--|--|
| #1 | **LLAMA/LLAMA2/LLaMa3** ||74 tks/s (7B)|
| #1 | **LLAMA/LLAMA2/LLaMa3/LLaMa3.1** ||74 tks/s (7B), 65 tks/s (LLaMa3.1 8B)|
| #2 | **Mistral** ||70 tks/s (7B)|
| #3 | **Phi (v1, v1.5, v2)** ||97 tks/s (2.7B, F32+BF16)|
| #4 | **Phi-3 (3.8B, 7B)** ||107 tks/s (3.8B)|
@@ -55,6 +55,11 @@ You may also run specific model using huggingface model-id, e.g.,
cargo run --release -- --port 2000 --model-id meta-llama/Llama-2-7b-chat-hf llama --repeat-last-n 64
```

Run the latest LLaMa3.1 using local weights:

```
cargo run --release -- --port 2000 --weight-path /home/Meta-Llama-3.1-8B-Instruct/ llama3 --repeat-last-n 64
```
### Step 2:

#### Option 1: Chat with ChatUI (recommended)
45 changes: 45 additions & 0 deletions src/lib.rs
@@ -25,6 +25,22 @@ pub enum ModelSelected {
max_gen_tokens: Option<usize>,
},

/// Select the llama3 model (default llama3.1-8b).
Llama3 {
/// Control the application of repeat penalty for the last n tokens
#[arg(long)]
repeat_last_n: Option<usize>,

#[arg(long)]
temperature: Option<f32>,

#[arg(long)]
penalty: Option<f32>,

#[arg(long)]
max_gen_tokens: Option<usize>,
},

/// Select the phi2 model (default 2.7b).
Phi2 {
/// Control the application of repeat penalty for the last n tokens
@@ -159,6 +175,12 @@ impl ToString for ModelSelected {
penalty: _,
max_gen_tokens: _,
} => "llama".to_string(),
ModelSelected::Llama3 {
repeat_last_n: _,
temperature: _,
penalty: _,
max_gen_tokens: _,
} => "llama3".to_string(),
ModelSelected::Phi2 {
repeat_last_n: _,
temperature: _,
@@ -237,6 +259,29 @@ pub fn get_model_loader(
"meta-llama/Llama-2-7b-chat-hf".to_string()
},
),
ModelSelected::Llama3 {
repeat_last_n,
temperature,
penalty,
max_gen_tokens,
} => (
Box::new(DefaultLoader::new(
SpecificConfig::new(
repeat_last_n,
temperature,
None,
None,
penalty,
max_gen_tokens,
),
"llama3".to_string(),
)),
if model_id.is_some() {
model_id.unwrap()
} else {
"meta-llama/Meta-Llama-3.1-8B-Instruct".to_string()
},
),
ModelSelected::Phi2 {
repeat_last_n,
temperature,
27 changes: 27 additions & 0 deletions src/openai/conversation/default_conversation.rs
@@ -17,6 +17,7 @@ pub enum SeparatorStyle {
NoColonTwo,
AddNewLineSingle,
Llama,
Llama3,
Phi,
Qwen2,
Gemma,
@@ -248,6 +249,32 @@ impl Conversation for DefaultConversation {
accum
}

SeparatorStyle::Llama3 => {
let mut accum = "<|begin_of_text|>".to_string();
for (i, message) in self.messages.iter().enumerate() {
let Message((_role, message)) = message;
if _role.clone() == self.roles.0 {
//user message
if let Some(message) = message {
accum += &format!(
"<|start_header_id|>user<|end_header_id|>\n\n {message} <|eot_id|>"
);
} else {
accum +=
&format!("<|start_header_id|>user<|end_header_id|>\n\n <|eot_id|>");
}
} else if _role.clone() == self.roles.1 {
//assistant message
if let Some(message) = message {
accum += &format!("<|start_header_id|>assistant<|end_header_id|>\n\n {message} <|eot_id|>");
}
} else if i == 0 && !system_prompt.is_empty() {
accum += &system_prompt;
}
}
accum
}

SeparatorStyle::Phi => {
let mut accum = "".to_string();
for (i, message) in self.messages.iter().enumerate() {
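For reference, the `SeparatorStyle::Llama3` arm added above assembles prompts in the Llama 3 instruct chat format. Below is a minimal, self-contained sketch of the string it is expected to build for a short exchange; the function and the `(role, message)` tuples are simplified stand-ins for the crate's `Conversation`/`Message` types, and the system prompt is simply prepended here rather than handled inside the loop as the real code does.

```
// Sketch: expected Llama 3 prompt assembly, mirroring SeparatorStyle::Llama3 above.
fn build_llama3_prompt(system_prompt: &str, turns: &[(&str, &str)]) -> String {
    let mut accum = "<|begin_of_text|>".to_string();
    if !system_prompt.is_empty() {
        // simplified: the real code appends the system prompt for the first message
        accum += system_prompt;
    }
    for (role, msg) in turns {
        match *role {
            "user" => {
                accum += &format!(
                    "<|start_header_id|>user<|end_header_id|>\n\n {msg} <|eot_id|>"
                );
            }
            "assistant" => {
                accum += &format!(
                    "<|start_header_id|>assistant<|end_header_id|>\n\n {msg} <|eot_id|>"
                );
            }
            _ => {}
        }
    }
    accum
}

fn main() {
    let prompt = build_llama3_prompt(
        "",
        &[("user", "Hello!"), ("assistant", "Hi, how can I help?")],
    );
    println!("{prompt}");
}
```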
6 changes: 3 additions & 3 deletions src/openai/models/gemma.rs
@@ -5,9 +5,9 @@ use candle::{DType, Device, Module, Result, Tensor, D};
use candle_core as candle;
use candle_nn::Activation;
use candle_nn::{linear_b, linear_no_bias as linear, Linear, RmsNorm, VarBuilder};
use either::Either;
use std::iter::zip;
use std::sync::Arc;

#[derive(serde::Deserialize, Debug, Clone)]
pub struct GemmaConfig {
pub attention_bias: bool,
@@ -45,8 +45,8 @@ impl GemmaConfig {
rms_norm_eps: self.rms_norm_eps,
rope_theta: self.rope_theta,
use_flash_attn,
bos_token_id: Some(self.bos_token_id as u32),
eos_token_id: Some(self.eos_token_id as u32),
bos_token_id: super::TokenID(Either::Left(Some(self.bos_token_id as u32))),
eos_token_id: super::TokenID(Either::Left(Some(self.eos_token_id as u32))),
max_seq_len: self.max_position_embeddings.unwrap_or(4096),
sliding_window: None,
hidden_act: hidden_act,
8 changes: 5 additions & 3 deletions src/openai/models/llama.rs
@@ -6,6 +6,7 @@ use candle_core as candle;
use candle_nn::{embedding, Embedding, Module, VarBuilder};
use candle_transformers::models::with_tracing::{linear_no_bias as linear, Linear, RmsNorm};
pub const MAX_SEQ_LEN: usize = 4096;
use crate::openai::models::TokenID;
use std::iter::zip;
#[derive(Debug, Clone, serde::Deserialize)]
pub struct LlamaConfig {
@@ -18,8 +19,9 @@ pub struct LlamaConfig {
pub rms_norm_eps: f64,
#[serde(default = "default_rope")]
pub rope_theta: f32,
pub bos_token_id: Option<u32>,
pub eos_token_id: Option<u32>,
pub bos_token_id: TokenID,
pub eos_token_id: TokenID,
pub max_position_embeddings: Option<usize>,
}

fn default_rope() -> f32 {
@@ -40,7 +42,7 @@ impl LlamaConfig {
use_flash_attn,
bos_token_id: self.bos_token_id,
eos_token_id: self.eos_token_id,
max_seq_len: MAX_SEQ_LEN,
max_seq_len: self.max_position_embeddings.unwrap_or(MAX_SEQ_LEN),
sliding_window: None,
hidden_act: None,
tie_word_embeddings: false,
5 changes: 3 additions & 2 deletions src/openai/models/mistral.rs
@@ -4,6 +4,7 @@ use crate::paged_attention::PagedAttention;
use candle_core::{DType, Device, Module, Result, Tensor, D};
use candle_nn::{Activation, VarBuilder};
use candle_transformers::models::with_tracing::{linear_no_bias, Linear, RmsNorm};
use either::Either;
use std::iter::zip;
use std::sync::Arc;

@@ -37,8 +38,8 @@ impl MistralConfig {
rms_norm_eps: self.rms_norm_eps,
rope_theta: self.rope_theta,
use_flash_attn,
bos_token_id: Some(self.bos_token_id as u32),
eos_token_id: Some(self.eos_token_id as u32),
bos_token_id: super::TokenID(Either::Left(Some(self.bos_token_id as u32))),
eos_token_id: super::TokenID(Either::Left(Some(self.eos_token_id as u32))),
max_seq_len: self.max_position_embeddings,
sliding_window: self.sliding_window,
hidden_act: Some(self.hidden_act),
9 changes: 7 additions & 2 deletions src/openai/models/mod.rs
@@ -14,6 +14,11 @@ use std::collections::HashMap;
#[derive(Deserialize, Debug, Clone)]
pub struct RopeScaling(#[serde(with = "either::serde_untagged")] pub Either<Vec<f64>, String>);

#[derive(Deserialize, Debug, Clone)]
pub struct TokenID(
#[serde(with = "either::serde_untagged")] pub Either<Option<u32>, Option<Vec<u32>>>,
);

#[derive(Debug, Clone)]
pub struct Config {
pub hidden_size: usize,
@@ -25,8 +30,8 @@ pub struct Config {
pub use_flash_attn: bool,
pub rms_norm_eps: f64,
pub rope_theta: f64,
pub bos_token_id: Option<u32>,
pub eos_token_id: Option<u32>,
pub bos_token_id: TokenID,
pub eos_token_id: TokenID,
pub max_seq_len: usize,
pub sliding_window: Option<usize>,
pub hidden_act: Option<candle_nn::Activation>,
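The `TokenID` wrapper introduced above allows `bos_token_id` / `eos_token_id` to deserialize from either a single integer or a list of integers; the list form is what LLaMa3.1-style configs typically declare for `eos_token_id`. Below is a minimal sketch of the same `serde_untagged` pattern, assuming `serde` (with derive), `serde_json`, and `either` (with its `serde` feature) as dependencies; the sample token values are illustrative, not taken from a specific config.json.

```
// Sketch: untagged Either deserialization for token ids, mirroring TokenID above.
use either::Either;
use serde::Deserialize;

#[derive(Deserialize, Debug)]
struct TokenID(
    #[serde(with = "either::serde_untagged")] Either<Option<u32>, Option<Vec<u32>>>,
);

#[derive(Deserialize, Debug)]
struct Cfg {
    eos_token_id: TokenID,
}

fn main() -> serde_json::Result<()> {
    // single id (older LLaMa-style configs)
    let single: Cfg = serde_json::from_str(r#"{ "eos_token_id": 2 }"#)?;
    // list of ids (LLaMa3.1-style configs with several stop tokens)
    let list: Cfg = serde_json::from_str(r#"{ "eos_token_id": [128001, 128008, 128009] }"#)?;
    println!("{:?}", single.eos_token_id);
    println!("{:?}", list.eos_token_id);
    Ok(())
}
```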
5 changes: 3 additions & 2 deletions src/openai/models/phi2.rs
@@ -6,6 +6,7 @@ use candle_nn::{Activation, VarBuilder};
use candle_transformers::models::with_tracing::{
layer_norm, linear_no_bias as linear, Embedding, LayerNorm, Linear,
};
use either::Either;
use serde::Deserialize;
use std::iter::zip;

@@ -41,8 +42,8 @@ impl Phi2Config {
num_key_value_heads: self.num_key_value_heads.unwrap_or(self.num_attention_heads),
rms_norm_eps: self.layer_norm_eps,
rope_theta: self.rope_theta,
bos_token_id: self.bos_token_id,
eos_token_id: self.eos_token_id,
bos_token_id: super::TokenID(Either::Left(self.bos_token_id)),
eos_token_id: super::TokenID(Either::Left(self.eos_token_id)),
max_seq_len: self.max_position_embeddings,
sliding_window: self.sliding_window,
hidden_act: Some(self.hidden_act),
4 changes: 2 additions & 2 deletions src/openai/models/phi3.rs
@@ -43,8 +43,8 @@ impl PhiConfig {
rms_norm_eps: self.rms_norm_eps,
rope_theta: self.rope_theta,
use_flash_attn,
bos_token_id: self.bos_token_id,
eos_token_id: self.eos_token_id,
bos_token_id: super::TokenID(Either::Left(self.bos_token_id)),
eos_token_id: super::TokenID(Either::Left(self.eos_token_id)),
max_seq_len: self.max_position_embeddings,
sliding_window: self.sliding_window,
hidden_act: Some(self.hidden_act),
5 changes: 3 additions & 2 deletions src/openai/models/qwen2.rs
@@ -5,6 +5,7 @@ use candle::{DType, Device, Module, Result, Tensor, D};
use candle_core as candle;
use candle_nn::VarBuilder;
use candle_transformers::models::with_tracing::{linear, linear_no_bias, Linear, RmsNorm};
use either::Either;
use std::iter::zip;
use std::sync::Arc;

@@ -40,8 +41,8 @@ impl QwenConfig {
rms_norm_eps: self.rms_norm_eps,
rope_theta: self.rope_theta,
use_flash_attn,
bos_token_id: Some(self.bos_token_id as u32),
eos_token_id: Some(self.eos_token_id as u32),
bos_token_id: super::TokenID(Either::Left(Some(self.bos_token_id as u32))),
eos_token_id: super::TokenID(Either::Left(Some(self.eos_token_id as u32))),
max_seq_len: self.max_position_embeddings,
sliding_window: Some(self.sliding_window),
hidden_act: Some(self.hidden_act),
5 changes: 3 additions & 2 deletions src/openai/models/stable_lm.rs
@@ -4,6 +4,7 @@ use crate::paged_attention::PagedAttention;
use candle_core::{DType, Device, Module, Result, Tensor, D};
use candle_nn::{Activation, LayerNorm, VarBuilder};
use candle_transformers::models::with_tracing::{linear, linear_no_bias, Linear};
use either::Either;
use std::iter::zip;
use std::sync::Arc;

@@ -41,8 +42,8 @@ impl StableLMConfig {
rms_norm_eps: self.norm_eps,
rope_theta: self.rope_theta,
use_flash_attn,
bos_token_id: Some(self.bos_token_id as u32),
eos_token_id: Some(self.eos_token_id as u32),
bos_token_id: super::TokenID(Either::Left(Some(self.bos_token_id as u32))),
eos_token_id: super::TokenID(Either::Left(Some(self.eos_token_id as u32))),
max_seq_len: self.max_position_embeddings,
sliding_window: self.sliding_window,
hidden_act: Some(self.hidden_act),
5 changes: 3 additions & 2 deletions src/openai/models/yi.rs
@@ -4,6 +4,7 @@ use crate::paged_attention::PagedAttention;
use candle_core::{DType, Device, Module, Result, Tensor, D};
use candle_nn::{Activation, VarBuilder};
use candle_transformers::models::with_tracing::{linear_no_bias, Linear, RmsNorm};
use either::Either;
use std::iter::zip;
use std::sync::Arc;

@@ -37,8 +38,8 @@ impl YiConfig {
rms_norm_eps: self.rms_norm_eps,
rope_theta: self.rope_theta,
use_flash_attn,
bos_token_id: Some(self.bos_token_id as u32),
eos_token_id: Some(self.eos_token_id as u32),
bos_token_id: super::TokenID(Either::Left(Some(self.bos_token_id as u32))),
eos_token_id: super::TokenID(Either::Left(Some(self.eos_token_id as u32))),
max_seq_len: self.max_position_embeddings,
sliding_window: self.sliding_window,
hidden_act: Some(self.hidden_act),
37 changes: 31 additions & 6 deletions src/openai/pipelines/pipeline.rs
@@ -1,4 +1,5 @@
use super::{get_token, ModelLoader, ModelPaths, ModulePipeline, TokenOrFinishReason};
use crate::openai::models::TokenID;
use crate::openai::sampling_params::{Logprobs, TopLogprob};
use crate::scheduler::sequence::SequenceGroup;
use crate::{
@@ -31,6 +32,7 @@ use candle_core::{DType, Device, Tensor};
use candle_examples::token_output_stream::TokenOutputStream;
use candle_nn::VarBuilder;
use candle_transformers::generation::{LogitsProcessor, Sampling};
use either::Either;
use either::Either::{Left, Right};
use hf_hub::{api::sync::ApiBuilder, Repo, RepoType};
use std::{path::PathBuf, sync::Arc};
@@ -170,7 +172,7 @@ impl ModelLoader for DefaultLoader {
let specific_args = self.config.clone();

let config = match self.name.as_str() {
"llama" => {
"llama" | "llama3" => {
let config: LlamaConfig = try_api!(serde_json::from_slice(&try_api!(
std::fs::read(paths.get_config_filename())
),));
@@ -238,6 +240,10 @@
LLMModel::LLAMA(try_api!(Llama::load(vb, &config, dtype, &device))),
SeparatorStyle::Llama,
),
"llama3" => (
LLMModel::LLAMA(try_api!(Llama::load(vb, &config, dtype, &device))),
SeparatorStyle::Llama3,
),
"phi2" => (
LLMModel::Phi2(try_api!(Phi2::new(vb, &config, dtype, &device))),
SeparatorStyle::Phi,
@@ -296,13 +302,32 @@

println!("{:?}", pipeline_config);

let eos_token = match tokenizer.get_token("<|endoftext|>") {
Some(token) => token,
None => tokenizer.tokenizer().token_to_id(EOS_TOKEN).unwrap(),
};
let mut stop_token_ids = Vec::<u32>::new();
stop_token_ids.push(eos_token);

match &config.eos_token_id {
//eos_token defined in the config
TokenID(Either::Left(eos_token)) => {
if let Some(tk) = eos_token {
stop_token_ids.push(*tk);
}
}
TokenID(Either::Right(eos_token_list)) => {
if let Some(tks) = eos_token_list {
stop_token_ids.extend(tks)
}
}
}

if stop_token_ids.len() == 0 {
//if no eos_token defined in the config, use default
let eos_token = match tokenizer.get_token("<|endoftext|>") {
Some(token) => token,
_ => tokenizer.tokenizer().token_to_id(EOS_TOKEN).unwrap_or(0),
};
stop_token_ids.push(eos_token);
}

//custome stop tokens

(GitHub Actions / Typos annotation on line 330 of src/openai/pipelines/pipeline.rs: "custome" should be "custom".)
if let Some(custom_stop) = &config.custom_stop_tokens {
for stop in custom_stop {
match tokenizer.get_token(&stop) {
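The stop-token handling above first collects whatever `eos_token_id`(s) the model config declares and only falls back to the tokenizer's `<|endoftext|>` / EOS lookup when the config declares none. Below is a condensed sketch of that decision flow; `TokenID` mirrors the wrapper in src/openai/models/mod.rs, and `lookup_eos` is a hypothetical stand-in for the tokenizer lookup, not the crate's actual API.

```
// Sketch: collect stop tokens from the config, falling back to a tokenizer lookup.
use either::Either;

struct TokenID(Either<Option<u32>, Option<Vec<u32>>>);

fn stop_token_ids(eos_token_id: &TokenID, lookup_eos: impl Fn() -> Option<u32>) -> Vec<u32> {
    let mut ids = Vec::new();
    match &eos_token_id.0 {
        // a single eos token declared in config.json
        Either::Left(Some(tk)) => ids.push(*tk),
        // a list of eos tokens (e.g. the several stop tokens LLaMa3.1 declares)
        Either::Right(Some(tks)) => ids.extend(tks),
        _ => {}
    }
    if ids.is_empty() {
        // no eos token in the config: fall back to the tokenizer, defaulting to 0
        ids.push(lookup_eos().unwrap_or(0));
    }
    ids
}

fn main() {
    let cfg = TokenID(Either::Right(Some(vec![128001, 128008, 128009])));
    println!("{:?}", stop_token_ids(&cfg, || Some(2)));
}
```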
