
Configurable kvcache & fix repeat chat history #41

Merged
merged 8 commits on Jun 20, 2024
41 changes: 30 additions & 11 deletions src/main.rs
@@ -56,6 +56,14 @@ struct Args {

#[arg(long, default_value_t = false)]
cpu: bool,

/// Available GPU memory for kvcache (MB)
#[arg(long, default_value_t = 4096)]
kvcache_mem: usize,

/// Record conversation on the server (default false; when disabled, the client needs to keep the chat history)
#[arg(long)]
record_conversation: bool,
}

#[actix_web::main]
@@ -72,33 +80,44 @@ async fn main() -> Result<(), APIError> {
_ => loader.download_model(model_id, None, args.hf_token, args.hf_token_path)?,
};

let dtype = match args.dtype.as_deref() {
Some("f16") => DType::F16,
Some("bf16") => DType::BF16,
Some("f32") => DType::F32,
let (dtype, dsize) = match args.dtype.as_deref() {
Some("f16") => (DType::F16, 2),
Some("bf16") => (DType::BF16, 2),
Some("f32") => (DType::F32, 4),
Some(dtype) => panic!("Unsupported dtype {dtype}"),
None => DType::BF16,
None => (DType::BF16, 2),
};

let device = candle_examples::device(args.cpu).unwrap();
let model = loader.load_model(paths, dtype, device)?;
let config = model.0.get_model_config();
let num_gpu_blocks = args.kvcache_mem * 1024 * 1024
/ dsize
/ args.block_size
/ config.get_num_kv_heads()
/ config.get_head_size()
/ config.get_num_hidden_layers()
/ 2;
let cache_config = CacheConfig {
block_size: args.block_size,
num_gpu_blocks: Some(num_gpu_blocks),
num_cpu_blocks: Some(32),
fully_init: true,
};
println!("Cache config {:?}", cache_config);

let llm_engine = LLMEngine::new(
model.0,
SchedulerConfig {
max_num_seqs: args.max_num_seqs,
},
CacheConfig {
block_size: args.block_size,
num_gpu_blocks: Some(64),
num_cpu_blocks: Some(64),
fully_init: true,
},
cache_config,
)?;

let server_data = OpenAIServerData {
pipeline_config: model.1,
model: Arc::new(Mutex::new(llm_engine)),
record_conversation: args.record_conversation,
device: Device::Cpu,
};

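The new `kvcache_mem` argument replaces the previously hard-coded `num_gpu_blocks: Some(64)` with a value derived from the memory budget, the dtype width (`dsize`), and the model shape. Below is a minimal standalone sketch of that formula; the concrete numbers (4096 MB budget, f16 at 2 bytes per element, block size 16, 32 KV heads, head size 128, 32 layers) are illustrative assumptions standing in for the real `args` and `config` accessors.

```rust
/// Sketch of the block-count formula from this PR, with the model shape
/// passed in explicitly instead of read from `ConfigLike`.
fn num_gpu_blocks(
    kvcache_mem_mb: usize,
    dsize: usize,       // bytes per element: 2 for f16/bf16, 4 for f32
    block_size: usize,  // tokens per cache block
    num_kv_heads: usize,
    head_size: usize,
    num_layers: usize,
) -> usize {
    // Each block stores `block_size` tokens; every token needs
    // num_kv_heads * head_size elements of `dsize` bytes per layer,
    // and the trailing factor of 2 accounts for separate K and V caches.
    kvcache_mem_mb * 1024 * 1024
        / dsize
        / block_size
        / num_kv_heads
        / head_size
        / num_layers
        / 2
}

fn main() {
    // Assumed shape: 4096 MB, f16, block size 16, 32 KV heads,
    // head size 128, 32 layers -> 512 blocks = 8192 cacheable token slots.
    let blocks = num_gpu_blocks(4096, 2, 16, 32, 128, 32);
    println!("num_gpu_blocks = {blocks}");
}
```

Since every step is integer division, the computed count can only round down, so it never exceeds the requested memory budget.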
3 changes: 3 additions & 0 deletions src/openai/conversation/default_conversation.rs
@@ -112,6 +112,9 @@ impl Conversation for DefaultConversation {
&self.roles
}

fn clear_message(&mut self) {
self.messages.clear()
}
/// Convert this conversation to a String prompt
fn get_prompt(&mut self) -> String {
let system_prompt = self.system_template.format(&[self.system_message.clone()]);
2 changes: 2 additions & 0 deletions src/openai/conversation/mod.rs
@@ -13,4 +13,6 @@ pub trait Conversation {
fn get_roles(&self) -> &(String, String);

fn get_prompt(&mut self) -> String;

fn clear_message(&mut self);
}
1 change: 1 addition & 0 deletions src/openai/mod.rs
@@ -40,6 +40,7 @@ pub struct PipelineConfig {
pub struct OpenAIServerData<'s> {
pub model: Arc<Mutex<LLMEngine<'s>>>,
pub pipeline_config: PipelineConfig,
pub record_conversation: bool,
pub device: Device,
}

5 changes: 4 additions & 1 deletion src/openai/openai_server.rs
@@ -32,7 +32,9 @@ async fn get_gen_prompt(
request: &web::Json<ChatCompletionRequest>,
) -> Result<String, APIError> {
let mut model = data.model.lock().unwrap();
let conversation = model.get_mut_pipeline().get_conversation();
let conversation = model
.get_mut_pipeline()
.get_conversation(data.record_conversation);

match &request.messages {
Messages::Literal(msg) => {
@@ -129,6 +131,7 @@
return Either::Left(Err(prompt.err().unwrap()));
}
let prompt = prompt.unwrap();
println!("\n\n\nPrompt {:?}", prompt);

let token_ids = check_length(&request, prompt, &data);
if token_ids.is_err() {
5 changes: 4 additions & 1 deletion src/openai/pipelines/llama.rs
@@ -300,7 +300,10 @@ impl<'s> ModulePipeline<'s> for LlamaPipeline {
&self.tokenizer
}

fn get_conversation(&mut self) -> &mut dyn Conversation {
fn get_conversation(&mut self, with_history: bool) -> &mut dyn Conversation {
if !with_history {
self.conversation.clear_message();
}
&mut self.conversation
}

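The repeated-history fix hinges on the new `with_history` parameter: when the server runs without conversation recording, the pipeline clears its stored messages before the next prompt is built, so the prompt contains only what the client just sent. The sketch below illustrates that pattern with toy types; `ToyConversation` and `ToyPipeline` are illustrative stand-ins, not the crate's real structs.

```rust
trait Conversation {
    fn append_message(&mut self, role: String, content: String);
    fn clear_message(&mut self);
    fn get_prompt(&mut self) -> String;
}

struct ToyConversation {
    messages: Vec<(String, String)>,
}

impl Conversation for ToyConversation {
    fn append_message(&mut self, role: String, content: String) {
        self.messages.push((role, content));
    }
    fn clear_message(&mut self) {
        self.messages.clear();
    }
    fn get_prompt(&mut self) -> String {
        self.messages
            .iter()
            .map(|(role, content)| format!("{role}: {content}\n"))
            .collect()
    }
}

struct ToyPipeline {
    conversation: ToyConversation,
}

impl ToyPipeline {
    // Mirrors the PR: without history recording, drop whatever the server
    // accumulated so the prompt is rebuilt solely from the incoming request.
    fn get_conversation(&mut self, with_history: bool) -> &mut dyn Conversation {
        if !with_history {
            self.conversation.clear_message();
        }
        &mut self.conversation
    }
}

fn main() {
    let mut pipeline = ToyPipeline {
        conversation: ToyConversation { messages: Vec::new() },
    };

    // First request.
    let conv = pipeline.get_conversation(false);
    conv.append_message("user".into(), "Hello".into());
    println!("--- prompt 1 ---\n{}", conv.get_prompt());

    // Second request: the client resends the full history, and the server's
    // old copy has been cleared, so nothing is duplicated in the prompt.
    let conv = pipeline.get_conversation(false);
    conv.append_message("user".into(), "Hello".into());
    conv.append_message("assistant".into(), "Hi!".into());
    conv.append_message("user".into(), "How are you?".into());
    println!("--- prompt 2 ---\n{}", conv.get_prompt());
}
```

Note that `llm_engine.rs` passes `true` when it calls `get_conversation` only to read the role names, presumably so that fetching metadata does not wipe the stored messages when recording is enabled.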
2 changes: 1 addition & 1 deletion src/openai/pipelines/llm_engine.rs
@@ -181,7 +181,7 @@ impl<'a> LLMEngine<'a> {
.map_err(APIError::from)?;
let choice = ChatChoice {
message: ChatChoiceData {
role: self.pipeline.get_conversation().get_roles().0.clone(),
role: self.pipeline.get_conversation(true).get_roles().0.clone(),
content: Some(data),
},
finish_reason: Some(seq.deref_mut().get_finish_reason().clone()),
2 changes: 1 addition & 1 deletion src/openai/pipelines/mod.rs
@@ -40,7 +40,7 @@ pub trait ModulePipeline<'s>: Send + Sync {

fn tokenizer(&self) -> &TokenOutputStream;

fn get_conversation(&mut self) -> &mut dyn Conversation;
fn get_conversation(&mut self, with_history: bool) -> &mut dyn Conversation;

fn get_model_config(&self) -> Box<dyn ConfigLike>;

2 changes: 1 addition & 1 deletion src/scheduler/cache_engine.rs
@@ -11,7 +11,7 @@ use crate::{
try_api,
};

#[derive(Clone)]
#[derive(Clone, Debug)]
pub struct CacheConfig {
pub block_size: usize,
pub num_gpu_blocks: Option<usize>, // Set after profiling init