Error prompt when the requested message exceeds model capacity
guoqingbao committed Jul 8, 2024
1 parent 211346e commit 97c5630
Showing 4 changed files with 43 additions and 16 deletions.
3 changes: 2 additions & 1 deletion src/openai/mod.rs
@@ -31,9 +31,10 @@ where
}
}

#[derive(Clone)]
#[derive(Clone, Debug)]
pub struct PipelineConfig {
pub max_model_len: usize,
pub default_max_tokens: usize,
}

#[derive(Clone)]
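The change above adds a `Debug` derive and a `default_max_tokens` field to `PipelineConfig`. A minimal sketch of how the extended struct can be constructed and debug-printed (the field values below are illustrative only, not taken from the commit):

```rust
#[derive(Clone, Debug)]
pub struct PipelineConfig {
    pub max_model_len: usize,
    pub default_max_tokens: usize,
}

fn main() {
    // Illustrative values: a 4096-token context with one tenth reserved for generation.
    let config = PipelineConfig {
        max_model_len: 4096,
        default_max_tokens: 409,
    };
    // The new Debug derive allows the config to be logged directly.
    println!("{:?}", config);
}
```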
22 changes: 11 additions & 11 deletions src/openai/openai_server.rs
@@ -76,22 +76,20 @@ async fn check_length(
.map_err(APIError::from)?
};

let max_tokens = if let Some(max_toks) = request.max_tokens {
max_toks
} else {
data.pipeline_config.max_model_len - token_ids.len()
};
let max_gen_tokens = request
.max_tokens
.unwrap_or(data.pipeline_config.default_max_tokens);

if token_ids.len() + max_tokens > data.pipeline_config.max_model_len {
if token_ids.len() + max_gen_tokens > data.pipeline_config.max_model_len {
Err(APIError::new(format!(
"This model's maximum context length is {} tokens. \
However, you requested {} tokens ({} in the messages, \
{} in the completion). Please reduce the length of the \
messages or completion.",
{} in the completion). \nPlease clear the chat history or reduce the length of the \
messages.",
data.pipeline_config.max_model_len,
max_tokens + token_ids.len(),
max_gen_tokens + token_ids.len(),
token_ids.len(),
max_tokens
max_gen_tokens
)))
} else {
Ok(token_ids)
@@ -157,7 +155,9 @@ async fn chat_completions(
request.stop.clone(),
request.stop_token_ids.clone().unwrap_or_default(),
request.ignore_eos.unwrap_or(false),
request.max_tokens.unwrap_or(1024),
request
.max_tokens
.unwrap_or(data.pipeline_config.default_max_tokens),
None,
None,
request.skip_special_tokens.unwrap_or(true),
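The effect of this change can be summarized outside the server code: the generation budget now defaults to the pipeline's `default_max_tokens` instead of `max_model_len - prompt_len`, and the error text asks the user to clear the chat history. A simplified sketch under those assumptions (the function name `check_budget` is hypothetical, not part of the commit):

```rust
/// Hypothetical, simplified version of the length check performed in check_length.
fn check_budget(
    prompt_tokens: usize,
    requested_max_tokens: Option<usize>,
    max_model_len: usize,
    default_max_tokens: usize,
) -> Result<(), String> {
    // Fall back to the pipeline-wide default instead of the remaining context window.
    let max_gen_tokens = requested_max_tokens.unwrap_or(default_max_tokens);
    if prompt_tokens + max_gen_tokens > max_model_len {
        Err(format!(
            "This model's maximum context length is {} tokens. \
             However, you requested {} tokens ({} in the messages, \
             {} in the completion). \nPlease clear the chat history or reduce the length of the \
             messages.",
            max_model_len,
            prompt_tokens + max_gen_tokens,
            prompt_tokens,
            max_gen_tokens
        ))
    } else {
        Ok(())
    }
}
```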
23 changes: 21 additions & 2 deletions src/openai/pipelines/pipeline.rs
@@ -32,6 +32,8 @@ use std::{iter::zip, path::PathBuf, sync::Arc};
use tokenizers::Tokenizer;
const EOS_TOKEN: &str = "</s>";
const SAMPLING_SEED: u64 = 299792458;
const MIN_GEN_TOKENS: usize = 128;
const MAX_GEN_TOKENS: usize = 4096;

#[derive(Debug, Clone)]
pub struct SpecificConfig {
@@ -160,6 +162,8 @@ impl<'a> ModelLoader<'a> for DefaultLoader {
_ => panic!(""),
};

println!("Model {:?}", config);

println!("Loading {} model.", self.name);

let vb = match unsafe {
@@ -192,9 +196,24 @@

println!("Done loading.");

//max is https://huggingface.co/docs/transformers/model_doc/llama2#transformers.LlamaConfig.max_position_embeddings
//max and min number of tokens generated per request
let mut default_max_tokens = config.max_seq_len / 10;
if default_max_tokens < MIN_GEN_TOKENS {
default_max_tokens = MIN_GEN_TOKENS;
} else if default_max_tokens > MAX_GEN_TOKENS {
default_max_tokens = MAX_GEN_TOKENS;
}

let pipeline_config = PipelineConfig {
max_model_len: 4096,
max_model_len: config.max_seq_len,
default_max_tokens,
};

println!("{:?}", pipeline_config);

let eos_token = match tokenizer.get_token("<|endoftext|>") {
Some(token) => token,
None => tokenizer.tokenizer().token_to_id(EOS_TOKEN).unwrap(),
};

let eos_token = match tokenizer.get_token("<|endoftext|>") {
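The default generation budget computed above, one tenth of the model's context length bounded by the new `MIN_GEN_TOKENS` and `MAX_GEN_TOKENS` constants, can be written equivalently with `usize::clamp`. A minimal sketch using the same constants as the commit:

```rust
const MIN_GEN_TOKENS: usize = 128;
const MAX_GEN_TOKENS: usize = 4096;

/// Equivalent to the if/else chain in the commit: one tenth of the
/// context length, bounded to [MIN_GEN_TOKENS, MAX_GEN_TOKENS].
fn default_max_tokens(max_seq_len: usize) -> usize {
    (max_seq_len / 10).clamp(MIN_GEN_TOKENS, MAX_GEN_TOKENS)
}

fn main() {
    assert_eq!(default_max_tokens(1024), 128);    // 102 is clamped up to 128
    assert_eq!(default_max_tokens(4096), 409);    // within bounds
    assert_eq!(default_max_tokens(131072), 4096); // 13107 is clamped down to 4096
}
```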
11 changes: 9 additions & 2 deletions src/openai/responses.rs
@@ -1,5 +1,5 @@
use crate::openai::sampling_params::Logprobs;
use actix_web::error;
use actix_web::{error, HttpResponse};
use derive_more::{Display, Error};

use serde::{Deserialize, Serialize};
@@ -10,7 +10,14 @@ pub struct APIError {
data: String,
}

impl error::ResponseError for APIError {}
impl error::ResponseError for APIError {
fn error_response(&self) -> HttpResponse {
// Pack the error into JSON so that the client can handle it
HttpResponse::BadRequest()
.content_type("application/json")
.json(self.data.to_string())
}
}

impl APIError {
pub fn new(data: String) -> Self {
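With this implementation, any handler error of type `APIError` now reaches the client as a `400 Bad Request` whose body is the message serialized as a JSON string. A small sketch of how the new behaviour could be exercised (the `candle_vllm` crate path is an assumption; only the `APIError` behaviour mirrors the diff):

```rust
use actix_web::error::ResponseError;
use actix_web::http::StatusCode;

// Assumed crate path for the APIError type changed in src/openai/responses.rs.
use candle_vllm::openai::responses::APIError;

fn main() {
    let err = APIError::new(
        "This model's maximum context length is 4096 tokens.".to_string(),
    );
    // The new error_response() packs the message into a 400 response
    // with an application/json body.
    let resp = err.error_response();
    assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
}
```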
