From 97c563091f4f5fe13be729fb80f78a23b77f520f Mon Sep 17 00:00:00 2001
From: Guoqing Bao
Date: Mon, 8 Jul 2024 12:12:06 +0800
Subject: [PATCH] Error prompt for requested message exceeds model capacity

---
 src/openai/mod.rs                |  3 ++-
 src/openai/openai_server.rs      | 22 +++++++++++-----------
 src/openai/pipelines/pipeline.rs | 23 +++++++++++++++++++++--
 src/openai/responses.rs          | 11 +++++++++--
 4 files changed, 43 insertions(+), 16 deletions(-)

diff --git a/src/openai/mod.rs b/src/openai/mod.rs
index 20669ed..191828d 100644
--- a/src/openai/mod.rs
+++ b/src/openai/mod.rs
@@ -31,9 +31,10 @@ where
     }
 }
 
-#[derive(Clone)]
+#[derive(Clone, Debug)]
 pub struct PipelineConfig {
     pub max_model_len: usize,
+    pub default_max_tokens: usize,
 }
 
 #[derive(Clone)]
diff --git a/src/openai/openai_server.rs b/src/openai/openai_server.rs
index 28b7f2f..fd05693 100644
--- a/src/openai/openai_server.rs
+++ b/src/openai/openai_server.rs
@@ -76,22 +76,20 @@ async fn check_length(
             .map_err(APIError::from)?
     };
 
-    let max_tokens = if let Some(max_toks) = request.max_tokens {
-        max_toks
-    } else {
-        data.pipeline_config.max_model_len - token_ids.len()
-    };
+    let max_gen_tokens = request
+        .max_tokens
+        .unwrap_or(data.pipeline_config.default_max_tokens);
 
-    if token_ids.len() + max_tokens > data.pipeline_config.max_model_len {
+    if token_ids.len() + max_gen_tokens > data.pipeline_config.max_model_len {
         Err(APIError::new(format!(
             "This model's maximum context length is {} tokens. \
             However, you requested {} tokens ({} in the messages, \
-            {} in the completion). Please reduce the length of the \
-            messages or completion.",
+            {} in the completion). \nPlease clear the chat history or reduce the length of the \
+            messages.",
             data.pipeline_config.max_model_len,
-            max_tokens + token_ids.len(),
+            max_gen_tokens + token_ids.len(),
             token_ids.len(),
-            max_tokens
+            max_gen_tokens
         )))
     } else {
         Ok(token_ids)
@@ -157,7 +155,9 @@ async fn chat_completions(
         request.stop.clone(),
         request.stop_token_ids.clone().unwrap_or_default(),
         request.ignore_eos.unwrap_or(false),
-        request.max_tokens.unwrap_or(1024),
+        request
+            .max_tokens
+            .unwrap_or(data.pipeline_config.default_max_tokens),
         None,
         None,
         request.skip_special_tokens.unwrap_or(true),
diff --git a/src/openai/pipelines/pipeline.rs b/src/openai/pipelines/pipeline.rs
index 4885eb7..df214ca 100644
--- a/src/openai/pipelines/pipeline.rs
+++ b/src/openai/pipelines/pipeline.rs
@@ -32,6 +32,8 @@ use std::{iter::zip, path::PathBuf, sync::Arc};
 use tokenizers::Tokenizer;
 const EOS_TOKEN: &str = "</s>";
 const SAMPLING_SEED: u64 = 299792458;
+const MIN_GEN_TOKENS: usize = 128;
+const MAX_GEN_TOKENS: usize = 4096;
 
 #[derive(Debug, Clone)]
 pub struct SpecificConfig {
@@ -160,6 +162,8 @@ impl<'a> ModelLoader<'a> for DefaultLoader {
         _ => panic!(""),
     };
 
+    println!("Model {:?}", config);
+
     println!("Loading {} model.", self.name);
 
     let vb = match unsafe {
@@ -192,9 +196,24 @@ impl<'a> ModelLoader<'a> for DefaultLoader {
 
     println!("Done loading.");
 
-    //max is https://huggingface.co/docs/transformers/model_doc/llama2#transformers.LlamaConfig.max_position_embeddings
+    //max and min number of tokens generated per request
+    let mut default_max_tokens = config.max_seq_len / 10;
+    if default_max_tokens < MIN_GEN_TOKENS {
+        default_max_tokens = MIN_GEN_TOKENS;
+    } else if default_max_tokens > MAX_GEN_TOKENS {
+        default_max_tokens = MAX_GEN_TOKENS;
+    }
+
     let pipeline_config = PipelineConfig {
-        max_model_len: 4096,
+        max_model_len: config.max_seq_len,
+        default_max_tokens,
+    };
+
+    println!("{:?}", pipeline_config);
+
+    let eos_token = match tokenizer.get_token("<|endoftext|>") {
+        Some(token) => token,
+        None => tokenizer.tokenizer().token_to_id(EOS_TOKEN).unwrap(),
     };
 
     let eos_token = match tokenizer.get_token("<|endoftext|>") {
diff --git a/src/openai/responses.rs b/src/openai/responses.rs
index 545cb3b..5d8baa5 100644
--- a/src/openai/responses.rs
+++ b/src/openai/responses.rs
@@ -1,5 +1,5 @@
 use crate::openai::sampling_params::Logprobs;
-use actix_web::error;
+use actix_web::{error, HttpResponse};
 use derive_more::{Display, Error};
 use serde::{Deserialize, Serialize};
 
@@ -10,7 +10,14 @@ pub struct APIError {
     data: String,
 }
 
-impl error::ResponseError for APIError {}
+impl error::ResponseError for APIError {
+    fn error_response(&self) -> HttpResponse {
+        //pack error to json so that client can handle it
+        HttpResponse::BadRequest()
+            .content_type("application/json")
+            .json(self.data.to_string())
+    }
+}
 
 impl APIError {
     pub fn new(data: String) -> Self {
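
Note: the capacity check added to check_length and the clamped default generation budget
reduce to simple arithmetic over token counts. The standalone sketch below is not part of
the patch; the helper names (default_max_tokens, check_length) and the sample numbers in
main() are illustrative only, and it uses usize::clamp in place of the if/else chain added
in pipeline.rs. With the PipelineConfig change above, an over-long request is now rejected
before generation and, via the new error_response impl, comes back to the client as a 400
with a JSON body instead of a plain-text actix error.

    // Standalone illustration of the capacity check introduced in this patch.
    // The constants mirror MIN_GEN_TOKENS / MAX_GEN_TOKENS; everything else here
    // (function names, sample numbers) is hypothetical.
    const MIN_GEN_TOKENS: usize = 128;
    const MAX_GEN_TOKENS: usize = 4096;

    // One tenth of the context window, clamped to [MIN_GEN_TOKENS, MAX_GEN_TOKENS];
    // equivalent to the if/else chain added in pipeline.rs.
    fn default_max_tokens(max_model_len: usize) -> usize {
        (max_model_len / 10).clamp(MIN_GEN_TOKENS, MAX_GEN_TOKENS)
    }

    // Reject a request when prompt tokens plus the generation budget exceed the context length.
    fn check_length(
        prompt_tokens: usize,
        requested: Option<usize>,
        max_model_len: usize,
    ) -> Result<(), String> {
        let max_gen_tokens = requested.unwrap_or_else(|| default_max_tokens(max_model_len));
        if prompt_tokens + max_gen_tokens > max_model_len {
            Err(format!(
                "This model's maximum context length is {} tokens. However, you requested {} tokens \
                 ({} in the messages, {} in the completion).",
                max_model_len,
                prompt_tokens + max_gen_tokens,
                prompt_tokens,
                max_gen_tokens
            ))
        } else {
            Ok(())
        }
    }

    fn main() {
        // 4096-token model, 4000-token prompt, default budget of 409 tokens -> rejected.
        assert!(check_length(4000, None, 4096).is_err());
        // The same prompt with an explicit 64-token completion budget still fits.
        assert!(check_length(4000, Some(64), 4096).is_ok());
    }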