Support Gemma model & remove repeat_kv (replaced with broadcast matmul in prefilling stage)
guoqingbao committed Jul 8, 2024
1 parent 97c5630 commit 54c634d
Showing 10 changed files with 498 additions and 23 deletions.
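The "broadcast matmul" in the commit title refers to computing grouped-query attention without first materializing repeated key/value heads: instead of `repeat_kv` copying each K/V head `n_groups` times to match the query heads, the query tensor is reshaped so its group dimension lines up against a singleton dimension on K, and the matmul broadcasts over it. A minimal sketch of the idea in candle; the function name and shapes are illustrative assumptions, not the commit's actual code:

```rust
use candle_core::{Result, Tensor};

/// Attention scores for grouped-query attention without repeat_kv.
/// q: (batch, n_heads, seq, head_dim)
/// k: (batch, n_kv_heads, seq, head_dim) with n_heads = n_kv_heads * n_groups
fn attn_scores_broadcast(q: &Tensor, k: &Tensor, n_groups: usize) -> Result<Tensor> {
    let (b, n_heads, seq, head_dim) = q.dims4()?;
    // Split query heads into (kv_head, group) so K can broadcast across groups.
    let q = q.reshape((b, n_heads / n_groups, n_groups, seq, head_dim))?;
    // (batch, n_kv_heads, 1, head_dim, seq): the singleton dim is broadcast.
    let k_t = k.unsqueeze(2)?.transpose(3, 4)?;
    // No copies of K are made, unlike repeat_kv.
    let scores = q.broadcast_matmul(&k_t)?; // (b, n_kv_heads, n_groups, seq, seq)
    scores.reshape((b, n_heads, seq, seq))
}
```

Scaling, masking, and the value matmul would follow the same broadcast pattern; prefilling is where this pays off most, since the scores span the whole prompt at once.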
6 changes: 3 additions & 3 deletions README.md
@@ -19,16 +19,16 @@ Currently, candle-vllm supports chat serving for the following models.

 | Model ID | Model Type | Supported | Speed (A100, BF16)
 |--|--|--|--|
-| #1 | **LLAMA/LLAMA2/LLaMa3** |✅|73 tks/s (7B)|
+| #1 | **LLAMA/LLAMA2/LLaMa3** |✅|74 tks/s (7B)|
 | #2 | Mistral |TBD|TBD|
 | #3 | Phi (v1, v1.5, v2) |TBD|TBD|
-| #4 | **Phi-3 (3.8B, 7B)** |✅|102 tks/s (3.8B)|
+| #4 | **Phi-3 (3.8B, 7B)** |✅|107 tks/s (3.8B)|
 | #5 | Yi |TBD|TBD|
 | #6 | StableLM |TBD|TBD|
 | #7 | BigCode/StarCode |TBD|TBD|
 | #8 | ChatGLM |TBD|TBD|
 | #9 | **QWen2 (1.8B, 7B)** |✅|148 tks/s (1.8B)|
-| #10 | Google Gemma |TBD|TBD|
+| #10 | **Google Gemma** |✅|130 tks/s (2B)|
 | #11 | Blip-large (Multimodal) |TBD|TBD|
 | #12 | Moondream-2 (Multimodal LLM) |TBD|TBD|

19 changes: 19 additions & 0 deletions src/lib.rs
@@ -29,6 +29,13 @@ pub enum ModelSelected {
         #[arg(long)]
         repeat_last_n: usize,
     },
+
+    /// Select the gemma model (default 2b).
+    Gemma {
+        /// Control the application of repeat penalty for the last n tokens
+        #[arg(long)]
+        repeat_last_n: usize,
+    },
 }

@@ -37,6 +44,7 @@ impl ToString for ModelSelected {
             ModelSelected::Llama { repeat_last_n: _ } => "llama".to_string(),
             ModelSelected::Phi3 { repeat_last_n: _ } => "phi3".to_string(),
             ModelSelected::Qwen2 { repeat_last_n: _ } => "qwen2".to_string(),
+            ModelSelected::Gemma { repeat_last_n: _ } => "gemma".to_string(),
         }
     }
 }
@@ -79,6 +87,17 @@ pub fn get_model_loader<'a>(
                 "Qwen/Qwen1.5-1.8B-Chat".to_string()
             },
         ),
+        ModelSelected::Gemma { repeat_last_n } => (
+            Box::new(DefaultLoader::new(
+                SpecificConfig::new(repeat_last_n),
+                "gemma".to_string(),
+            )),
+            if model_id.is_some() {
+                model_id.unwrap()
+            } else {
+                "google/gemma-2b-it".to_string()
+            },
+        ),
     }
 }
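For illustration, selecting Gemma without an explicit model id falls back to `google/gemma-2b-it`. A hedged call sketch, assuming `get_model_loader` takes the selection plus an optional model id as the body above suggests (`64` is an arbitrary example value):

```rust
let (loader, model_id) = get_model_loader(
    ModelSelected::Gemma { repeat_last_n: 64 },
    None, // no model id supplied, so the default is used
);
assert_eq!(model_id, "google/gemma-2b-it");
```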

16 changes: 16 additions & 0 deletions src/openai/conversation/default_conversation.rs
@@ -19,6 +19,7 @@ pub enum SeparatorStyle {
     Llama,
     Phi,
     Qwen2,
+    Gemma,
     ChatGLM,
     ChatML,
     ChatIntern,
@@ -290,6 +291,21 @@ impl Conversation for DefaultConversation {
                 accum
             }
 
+            SeparatorStyle::Gemma => {
+                let mut accum = "".to_string();
+                for (_, message) in self.messages.iter().enumerate() {
+                    let Message((_role, message)) = message;
+                    if let Some(message) = message {
+                        accum +=
+                            &format!("<bos><start_of_turn>{_role}\n {message} <end_of_turn>\n");
+                    } else {
+                        accum += &format!("<start_of_turn>{_role}\n <end_of_turn>\n");
+                    }
+                }
+                accum += "<start_of_turn>model\n";
+                accum
+            }
+
             SeparatorStyle::ChatGLM => {
                 let round_add_n = if self.name == "chatglm2" { 1 } else { 0 };
 
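The new `SeparatorStyle::Gemma` arm assembles prompts in Gemma's `<start_of_turn>` chat format. A self-contained sketch of the same logic over plain (role, content) pairs; the message type here is assumed for illustration, not the crate's actual struct:

```rust
fn gemma_prompt(messages: &[(&str, Option<&str>)]) -> String {
    let mut accum = String::new();
    for (role, content) in messages {
        match content {
            // A filled turn: role header, content, end-of-turn marker.
            Some(text) => {
                accum += &format!("<bos><start_of_turn>{role}\n {text} <end_of_turn>\n")
            }
            // An empty turn keeps only the role header.
            None => accum += &format!("<start_of_turn>{role}\n <end_of_turn>\n"),
        }
    }
    // The trailing header cues the model to produce its reply.
    accum + "<start_of_turn>model\n"
}
```

So `gemma_prompt(&[("user", Some("Hello"))])` yields `<bos><start_of_turn>user\n Hello <end_of_turn>\n<start_of_turn>model\n`.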
The remaining 7 changed files are not shown.