Support Yi & StableLM models, change default maximum length of generated tokens for smooth chat. #57

Merged 3 commits on Jul 15, 2024
README.md: 10 changes (6 additions, 4 deletions)
@@ -23,8 +23,8 @@ Currently, candle-vllm supports chat serving for the following models.
| #2 | **Mistral** |✅|70 tks/s (7B)|
| #3 | **Phi (v1, v1.5, v2)** |✅|97 tks/s (2.7B, F32+BF16)|
| #4 | **Phi-3 (3.8B, 7B)** |✅|107 tks/s (3.8B)|
| #5 | Yi |TBD|TBD|
| #6 | StableLM |TBD|TBD|
| #5 | **Yi** |✅|TBD|
| #6 | **StableLM** |✅|TBD|
| #5 | BigCode/StarCoder |TBD|TBD|
| #8 | ChatGLM |TBD|TBD|
| #9 | **QWen2 (1.8B, 7B)** |✅|148 tks/s (1.8B)|
@@ -133,7 +133,7 @@ For model-specific help, run `cargo run -- --port 2000 <MODEL_TYPE> --help`

For local model weights, run `cargo run --release -- --port 2000 --weight-path /home/llama2_7b/ llama --repeat-last-n 64`, changing the path as needed.

`MODEL_TYPE` = ["llama", "mistral", "phi2", "phi3", "qwen2", "gemma"]
`MODEL_TYPE` = ["llama", "mistral", "phi2", "phi3", "qwen2", "gemma", "yi", "stable-lm"]

`WEIGHT_FILE_PATH` = Corresponding weight path for the given model type

@@ -158,9 +158,11 @@ For chat streaming, the `stream` flag in the chat request needs to be set to `True`.
You may supply `penalty` and `temperature` to the model to **prevent potential repetitions**, for example:

```
cargo run --release -- --port 2000 --weight-path /home/mistral_7b/ mistral --repeat-last-n 32 --penalty 1.1 temperature 0.8
cargo run --release -- --port 2000 --weight-path /home/mistral_7b/ mistral --repeat-last-n 32 --penalty 1.1 --temperature 0.8
```

The `--max-gen-tokens` parameter controls the maximum number of output tokens per chat response. It defaults to 1/5 of `max_sequence_len`.
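
For example, the following serves the Mistral model and caps each response at 512 generated tokens (the weight path and the value `512` are illustrative, not defaults from this PR):

```
cargo run --release -- --port 2000 --weight-path /home/mistral_7b/ mistral --repeat-last-n 32 --penalty 1.1 --temperature 0.8 --max-gen-tokens 512
```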

## Report issue
Installing `candle-vllm` is as simple as the following steps. If you have any problems, please create an
[issue](https://github.com/EricLBuehler/candle-lora/issues).
src/lib.rs: 176 changes (170 additions, 6 deletions)
@@ -20,6 +20,9 @@ pub enum ModelSelected {

#[arg(long)]
penalty: Option<f32>,

#[arg(long)]
max_gen_tokens: Option<usize>,
},

/// Select the phi2 model (default 2.7b).
@@ -33,6 +36,9 @@ pub enum ModelSelected {

#[arg(long)]
penalty: Option<f32>,

#[arg(long)]
max_gen_tokens: Option<usize>,
},

/// Select the phi3 model (default 3.8b).
@@ -52,6 +58,9 @@ pub enum ModelSelected {

#[arg(long)]
penalty: Option<f32>,

#[arg(long)]
max_gen_tokens: Option<usize>,
},

/// Select the qwen model (default 1.8b).
@@ -71,6 +80,9 @@ pub enum ModelSelected {

#[arg(long)]
penalty: Option<f32>,

#[arg(long)]
max_gen_tokens: Option<usize>,
},

/// Select the gemma model (default 2b).
@@ -84,6 +96,9 @@ pub enum ModelSelected {

#[arg(long)]
penalty: Option<f32>,

#[arg(long)]
max_gen_tokens: Option<usize>,
},

/// Select the mistral model (default 7b).
@@ -97,6 +112,41 @@ pub enum ModelSelected {

#[arg(long)]
penalty: Option<f32>,

#[arg(long)]
max_gen_tokens: Option<usize>,
},

/// Select the Yi model (default 6b).
Yi {
/// Control the application of repeat penalty for the last n tokens
#[arg(long)]
repeat_last_n: Option<usize>,

#[arg(long)]
temperature: Option<f32>,

#[arg(long)]
penalty: Option<f32>,

#[arg(long)]
max_gen_tokens: Option<usize>,
},

/// Select the stable-lm model (default zephyr-3b).
StableLM {
/// Control the application of repeat penalty for the last n tokens
#[arg(long)]
repeat_last_n: Option<usize>,

#[arg(long)]
temperature: Option<f32>,

#[arg(long)]
penalty: Option<f32>,

#[arg(long)]
max_gen_tokens: Option<usize>,
},
}
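
For readers unfamiliar with clap's derive API, the sketch below shows how a `max_gen_tokens: Option<usize>` field annotated with `#[arg(long)]` becomes a `--max-gen-tokens` flag. This is a standalone illustration, not code from this PR, and assumes clap 4 with the `derive` feature:

```rust
use clap::Parser;

/// Minimal standalone example mirroring the flag added to each variant above.
#[derive(Parser, Debug)]
struct Args {
    /// Maximum number of tokens to generate per chat response.
    #[arg(long)]
    max_gen_tokens: Option<usize>,
}

fn main() {
    // Passing `--max-gen-tokens 512` yields `Some(512)`; omitting the flag
    // leaves `None`, and the server then falls back to its own default.
    let args = Args::parse();
    println!("{:?}", args.max_gen_tokens);
}
```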

@@ -107,36 +157,54 @@ impl ToString for ModelSelected {
repeat_last_n: _,
temperature: _,
penalty: _,
max_gen_tokens: _,
} => "llama".to_string(),
ModelSelected::Phi2 {
repeat_last_n: _,
temperature: _,
penalty: _,
max_gen_tokens: _,
} => "phi2".to_string(),
ModelSelected::Phi3 {
repeat_last_n: _,
temperature: _,
top_k: _,
top_p: _,
penalty: _,
max_gen_tokens: _,
} => "phi3".to_string(),
ModelSelected::Qwen2 {
repeat_last_n: _,
temperature: _,
top_k: _,
top_p: _,
penalty: _,
max_gen_tokens: _,
} => "qwen2".to_string(),
ModelSelected::Gemma {
repeat_last_n: _,
temperature: _,
penalty: _,
max_gen_tokens: _,
} => "gemma".to_string(),
ModelSelected::Mistral {
repeat_last_n: _,
temperature: _,
penalty: _,
max_gen_tokens: _,
} => "mistral".to_string(),
ModelSelected::Yi {
repeat_last_n: _,
temperature: _,
penalty: _,
max_gen_tokens: _,
} => "yi".to_string(),
ModelSelected::StableLM {
repeat_last_n: _,
temperature: _,
penalty: _,
max_gen_tokens: _,
} => "stablelm".to_string(),
}
}
}
@@ -150,9 +218,17 @@ pub fn get_model_loader<'a>(
repeat_last_n,
temperature,
penalty,
max_gen_tokens,
} => (
Box::new(DefaultLoader::new(
SpecificConfig::new(repeat_last_n, temperature, None, None, penalty),
SpecificConfig::new(
repeat_last_n,
temperature,
None,
None,
penalty,
max_gen_tokens,
),
"llama".to_string(),
)),
if model_id.is_some() {
@@ -165,9 +241,17 @@ pub fn get_model_loader<'a>(
repeat_last_n,
temperature,
penalty,
max_gen_tokens,
} => (
Box::new(DefaultLoader::new(
SpecificConfig::new(repeat_last_n, temperature, None, None, penalty),
SpecificConfig::new(
repeat_last_n,
temperature,
None,
None,
penalty,
max_gen_tokens,
),
"phi2".to_string(),
)),
if model_id.is_some() {
@@ -182,9 +266,17 @@ pub fn get_model_loader<'a>(
top_k,
top_p,
penalty,
max_gen_tokens,
} => (
Box::new(DefaultLoader::new(
SpecificConfig::new(repeat_last_n, temperature, top_k, top_p, penalty),
SpecificConfig::new(
repeat_last_n,
temperature,
top_k,
top_p,
penalty,
max_gen_tokens,
),
"phi3".to_string(),
)),
if model_id.is_some() {
@@ -199,9 +291,17 @@ pub fn get_model_loader<'a>(
top_k,
top_p,
penalty,
max_gen_tokens,
} => (
Box::new(DefaultLoader::new(
SpecificConfig::new(repeat_last_n, temperature, top_k, top_p, penalty),
SpecificConfig::new(
repeat_last_n,
temperature,
top_k,
top_p,
penalty,
max_gen_tokens,
),
"qwen2".to_string(),
)),
if model_id.is_some() {
@@ -214,9 +314,17 @@ pub fn get_model_loader<'a>(
repeat_last_n,
temperature,
penalty,
max_gen_tokens,
} => (
Box::new(DefaultLoader::new(
SpecificConfig::new(repeat_last_n, temperature, None, None, penalty),
SpecificConfig::new(
repeat_last_n,
temperature,
None,
None,
penalty,
max_gen_tokens,
),
"gemma".to_string(),
)),
if model_id.is_some() {
@@ -229,9 +337,17 @@ pub fn get_model_loader<'a>(
repeat_last_n,
temperature,
penalty,
max_gen_tokens,
} => (
Box::new(DefaultLoader::new(
SpecificConfig::new(repeat_last_n, temperature, None, None, penalty),
SpecificConfig::new(
repeat_last_n,
temperature,
None,
None,
penalty,
max_gen_tokens,
),
"mistral".to_string(),
)),
if model_id.is_some() {
@@ -240,6 +356,54 @@ pub fn get_model_loader<'a>(
"mistralai/Mistral-7B-Instruct-v0.3".to_string()
},
),

ModelSelected::Yi {
repeat_last_n,
temperature,
penalty,
max_gen_tokens,
} => (
Box::new(DefaultLoader::new(
SpecificConfig::new(
repeat_last_n,
temperature,
None,
None,
penalty,
max_gen_tokens,
),
"yi".to_string(),
)),
if model_id.is_some() {
model_id.unwrap()
} else {
"01-ai/Yi-6B-Chat".to_string()
},
),

ModelSelected::StableLM {
repeat_last_n,
temperature,
penalty,
max_gen_tokens,
} => (
Box::new(DefaultLoader::new(
SpecificConfig::new(
repeat_last_n,
temperature,
None,
None,
penalty,
max_gen_tokens,
),
"stablelm".to_string(),
)),
if model_id.is_some() {
model_id.unwrap()
} else {
"stabilityai/stablelm-zephyr-3b".to_string()
},
),
}
}
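
From the call sites above, `SpecificConfig::new` now takes six arguments, with `max_gen_tokens` appended last. The sketch below shows one plausible shape of the struct after this change; the field names are inferred from the enum fields and the argument order, not copied from the actual candle-vllm source, and the `top_k`/`top_p` types in particular are assumptions:

```rust
// Inferred sketch only; the real SpecificConfig in candle-vllm may differ
// in field names, types, visibility, or derives.
#[derive(Debug, Clone)]
pub struct SpecificConfig {
    repeat_last_n: Option<usize>,
    temperature: Option<f32>,
    top_k: Option<usize>, // assumed type
    top_p: Option<f32>,   // assumed type
    penalty: Option<f32>,
    max_gen_tokens: Option<usize>,
}

impl SpecificConfig {
    pub fn new(
        repeat_last_n: Option<usize>,
        temperature: Option<f32>,
        top_k: Option<usize>,
        top_p: Option<f32>,
        penalty: Option<f32>,
        max_gen_tokens: Option<usize>,
    ) -> Self {
        Self {
            repeat_last_n,
            temperature,
            top_k,
            top_p,
            penalty,
            max_gen_tokens,
        }
    }
}
```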
