diff --git a/src/openai/openai_server.rs b/src/openai/openai_server.rs
index 06e2b3d..44f05b8 100644
--- a/src/openai/openai_server.rs
+++ b/src/openai/openai_server.rs
@@ -184,17 +184,23 @@ pub async fn chat_completions(
     let stream_request = request.stream.is_some_and(|x| x);
     let model_name = request.model.clone();
 
-    //send completion request to inference engine
-    let mut model = data.model.lock().await;
-    model.add_request(
-        token_ids,
-        request_id.clone(),
-        SystemTime::now(),
-        sampling_params,
-        request.logprobs.unwrap_or(false),
-        Some(response_tx),
-    );
-    model.notify.notify_one();
+    let _ = tokio::task::spawn_blocking(move || {
+        tokio::runtime::Handle::current().block_on(async move {
+            {
+                //send completion request to inference engine
+                let mut model = data.model.lock().await;
+                model.add_request(
+                    token_ids,
+                    request_id.clone(),
+                    SystemTime::now(),
+                    sampling_params,
+                    request.logprobs.unwrap_or(false),
+                    Some(response_tx),
+                );
+                model.notify.notify_one();
+            }
+        });
+    });
 
     if stream_request {
         ChatResponder::Streamer(
diff --git a/src/openai/pipelines/llm_engine.rs b/src/openai/pipelines/llm_engine.rs
index ed121f1..2c6ddc8 100644
--- a/src/openai/pipelines/llm_engine.rs
+++ b/src/openai/pipelines/llm_engine.rs
@@ -83,51 +83,52 @@ impl LLMEngine {
         }));
 
         let engine_clone = engine.clone();
-        tokio::runtime::Handle::current().block_on(async move {
-            loop {
-                notify.notified().await; // Blocking call to wait for notification
-                let _ = tokio::time::sleep(tokio::time::Duration::from_millis(200)).await;
-                let mut e = engine.lock().await;
-                let result = e.generate_once().unwrap();
-                if result.is_empty() {
-                    continue;
-                }
-                for request_id in result.keys() {
-                    e.completion_records
-                        .insert(request_id.to_string(), result[request_id].clone());
-                }
-                finish_notify.notify_one();
-
-                //chat completion statistics
-                let overall_usage = ChatCompletionUsageResponse {
-                    request_id: "".to_string(),
-                    created: 0,
-                    completion_tokens: result
-                        .values()
-                        .map(|(_, usage)| usage.completion_tokens)
-                        .sum(),
-                    prompt_tokens: result.values().map(|(_, usage)| usage.prompt_tokens).sum(),
-                    total_tokens: result.values().map(|(_, usage)| usage.total_tokens).sum(),
-                    prompt_time_costs: result
-                        .values()
-                        .map(|(_, usage)| usage.prompt_time_costs)
-                        .max()
-                        .unwrap_or(0),
-                    completion_time_costs: result
-                        .values()
-                        .map(|(_, usage)| usage.completion_time_costs)
-                        .max()
-                        .unwrap_or(0),
-                };
+        tokio::task::spawn_blocking(move || {
+            tokio::runtime::Handle::current().block_on(async move {
+                loop {
+                    notify.notified().await; // Blocking call to wait for notification
+                    let _ = tokio::time::sleep(tokio::time::Duration::from_millis(200)).await;
+                    let mut e = engine.lock().await;
+                    let result = e.generate_once().unwrap();
+                    if result.is_empty() {
+                        continue;
+                    }
+                    for request_id in result.keys() {
+                        e.completion_records
+                            .insert(request_id.to_string(), result[request_id].clone());
+                    }
+                    finish_notify.notify_one();
+
+                    //chat completion statistics
+                    let overall_usage = ChatCompletionUsageResponse {
+                        request_id: "".to_string(),
+                        created: 0,
+                        completion_tokens: result
+                            .values()
+                            .map(|(_, usage)| usage.completion_tokens)
+                            .sum(),
+                        prompt_tokens: result.values().map(|(_, usage)| usage.prompt_tokens).sum(),
+                        total_tokens: result.values().map(|(_, usage)| usage.total_tokens).sum(),
+                        prompt_time_costs: result
+                            .values()
+                            .map(|(_, usage)| usage.prompt_time_costs)
+                            .max()
+                            .unwrap_or(0),
+                        completion_time_costs: result
+                            .values()
+                            .map(|(_, usage)| usage.completion_time_costs)
+                            .max()
+                            .unwrap_or(0),
+                    };
 
-                println!(
-                    "\r\n [{} requests] Prefilling: {} prompt tokens processed in {} seconds",
-                    result.len(),
-                    overall_usage.prompt_tokens,
-                    overall_usage.prompt_time_costs / 1000
-                );
+                    println!(
+                        "\r\n [{} requests] Prefilling: {} prompt tokens processed in {} seconds",
+                        result.len(),
+                        overall_usage.prompt_tokens,
+                        overall_usage.prompt_time_costs / 1000
+                    );
 
-                println!(
+                    println!(
                     "\r\n [{} requests] Decoding: {} tokens processed in {} seconds ({} tokens/s)",
                     result.len(),
                     overall_usage.completion_tokens,
@@ -139,7 +140,8 @@ impl LLMEngine {
                         1
                     }
                 );
-            }
+                }
+            })
         });
 
         Ok(engine_clone)
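Both hunks apply the same pattern: the code that takes the async Mutex and wakes the engine moves onto a dedicated blocking thread via tokio::task::spawn_blocking, with tokio::runtime::Handle::current().block_on(...) driving the async body there, so the long-running engine loop no longer ties up an async worker thread. Note that Handle::block_on panics when called from within an async execution context, which is why both call sites wrap it in spawn_blocking rather than calling it inline. Below is a minimal, self-contained sketch of the pattern, assuming only tokio; Engine, heavy_generate_once, the step bound, and the sleep durations are hypothetical stand-ins for this crate's LLMEngine, generate_once, and scheduling, not its actual API:

    use std::sync::Arc;
    use tokio::sync::{Mutex, Notify};

    struct Engine {
        steps: u32,
    }

    impl Engine {
        // Stand-in for generate_once(): CPU-bound work that would starve
        // other tasks if it ran directly on an async worker thread.
        fn heavy_generate_once(&mut self) -> u32 {
            std::thread::sleep(std::time::Duration::from_millis(50));
            self.steps += 1;
            self.steps
        }
    }

    #[tokio::main]
    async fn main() {
        let engine = Arc::new(Mutex::new(Engine { steps: 0 }));
        let notify = Arc::new(Notify::new());

        let loop_engine = engine.clone();
        let loop_notify = notify.clone();
        // Dedicated blocking thread, mirroring the llm_engine.rs hunk:
        // spawn_blocking inherits the runtime context, so Handle::current()
        // resolves and block_on can drive the loop off the worker pool.
        let driver = tokio::task::spawn_blocking(move || {
            tokio::runtime::Handle::current().block_on(async move {
                loop {
                    loop_notify.notified().await; // wait for a submitted request
                    let mut e = loop_engine.lock().await;
                    let step = e.heavy_generate_once();
                    println!("engine step {step}");
                    if step >= 3 {
                        break; // the real loop runs forever; bounded for the demo
                    }
                }
            });
        });

        // Submission side, mirroring the openai_server.rs hunk: take the
        // lock, enqueue the request, then wake the engine loop.
        for _ in 0..3 {
            {
                let _model = engine.lock().await;
                // model.add_request(...) in the real server
            }
            notify.notify_one();
            tokio::time::sleep(std::time::Duration::from_millis(100)).await;
        }

        driver.await.unwrap();
    }

Because Notify stores a permit when notify_one() fires before the loop reaches notified().await, a request submitted while the engine is mid-step is not lost, the next notified().await returns immediately.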