More elegant way for handling non-streaming finish signal. #66

Merged · 1 commit · Jul 24, 2024
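In short, this PR replaces the fixed 100 ms sleep that the non-streaming chat path used to wait on with an `Arc<Notify>` shared between the server state and the `LLMEngine`: the engine signals the notify once a request's completion record is stored, and the handler simply awaits that signal. A minimal, self-contained sketch of the handshake (illustrative only — plain `tokio` with a made-up task body, not the project's actual types):

```rust
use std::sync::Arc;
use tokio::sync::Notify;

#[tokio::main]
async fn main() {
    // One handle is shared between the request handler and the engine loop.
    let finish_notify = Arc::new(Notify::new());

    // "Engine" side: finish some (stand-in) work, then signal completion.
    let engine_notify = finish_notify.clone();
    tokio::spawn(async move {
        // ... generate and record the completion here ...
        engine_notify.notify_one();
    });

    // "Handler" side: instead of sleeping for a fixed interval,
    // block only until the engine says the response is ready.
    finish_notify.notified().await;
    println!("completion is ready");
}
```

The same two-ended pattern appears in the diffs below: `finish_notify.notify_one()` on the engine side and `data.finish_notify.notified().await` on the server side.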
src/main.rs — 4 changes: 3 additions & 1 deletion
@@ -137,21 +137,23 @@ async fn main() -> Result<(), APIError> {
         dtype: config.kv_cache_dtype,
     };
     println!("Cache config {:?}", cache_config);

+    let finish_notify = Arc::new(Notify::new());
     let llm_engine = LLMEngine::new(
         model.0,
         SchedulerConfig {
             max_num_seqs: args.max_num_seqs,
         },
         cache_config,
         Arc::new(Notify::new()),
+        finish_notify.clone(),
     )?;

     let server_data = OpenAIServerData {
         pipeline_config: model.1,
         model: llm_engine,
         record_conversation: args.record_conversation,
         device: Device::Cpu,
+        finish_notify: finish_notify.clone(),
     };

     println!("Server started at http://127.0.0.1:{}.", args.port);
src/openai/mod.rs — 3 changes: 2 additions & 1 deletion
@@ -1,7 +1,7 @@
 use candle_core::Device;
 use std::sync::Arc;
 use tokenizers::{EncodeInput, Encoding, Tokenizer};
-use tokio::sync::Mutex;
+use tokio::sync::{Mutex, Notify};

 use self::{pipelines::llm_engine::LLMEngine, responses::APIError};

@@ -45,6 +45,7 @@ pub struct OpenAIServerData {
     pub pipeline_config: PipelineConfig,
     pub record_conversation: bool,
     pub device: Device,
+    pub finish_notify: Arc<Notify>,
 }

 pub mod conversation;
src/openai/openai_server.rs — 9 changes: 8 additions & 1 deletion
@@ -219,8 +219,15 @@ pub async fn chat_completions(
         )
     } else {
         // wait until current response finished
-        tokio::time::sleep(Duration::from_millis(100)).await; //permits generation thread to work
+        data.finish_notify.notified().await;
         let model = data.model.lock().await;
+        if !model.completion_records.contains_key(&request_id) {
+            return ChatResponder::ModelError(APIError::from(format!(
+                "Unable to generate response for request {}",
+                request_id
+            )));
+        }
+
         let choices = &model.completion_records[&request_id].0;
         let usage = &model.completion_records[&request_id].1;

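A note on the added `contains_key` guard: the shared `Notify` carries no request id, so a woken handler cannot assume its own request is the one that finished; checking the records map first turns a missing entry into an API error instead of a panic on the indexing that follows. A hedged sketch of that wait-then-check shape, with a simplified records map and a hypothetical `wait_for_completion` helper standing in for the real types:

```rust
use std::{collections::HashMap, sync::Arc};
use tokio::sync::{Mutex, Notify};

// Simplified stand-in for the engine's completion_records map.
type Records = Arc<Mutex<HashMap<String, String>>>;

// Wait for the finish signal, then confirm that *this* request actually completed.
async fn wait_for_completion(
    request_id: &str,
    records: Records,
    finish_notify: Arc<Notify>,
) -> Result<String, String> {
    finish_notify.notified().await;
    let map = records.lock().await;
    match map.get(request_id) {
        // Indexing the map blindly would panic if the wake-up was for another request.
        Some(answer) => Ok(answer.clone()),
        None => Err(format!("Unable to generate response for request {request_id}")),
    }
}

#[tokio::main]
async fn main() {
    let records: Records = Arc::new(Mutex::new(HashMap::new()));
    let finish_notify = Arc::new(Notify::new());

    // "Engine" side: store the result, then signal completion.
    records.lock().await.insert("req-1".into(), "hello".into());
    finish_notify.notify_one();

    // "Handler" side: wakes via the stored permit and finds its record.
    println!("{:?}", wait_for_completion("req-1", records, finish_notify).await);
}
```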
src/openai/pipelines/llm_engine.rs — 4 changes: 4 additions & 0 deletions
@@ -49,6 +49,7 @@ pub struct LLMEngine {
     cache_engine: CacheEngine,
     sliding_window: Option<usize>,
     pub notify: Arc<Notify>,
+    pub finish_notify: Arc<Notify>,
     pub completion_records: HashMap<String, (Vec<ChatChoice>, ChatCompletionUsageResponse)>,
 }

@@ -58,6 +59,7 @@ impl LLMEngine {
         scheduler_config: SchedulerConfig,
         cache_config: CacheConfig,
         notify: Arc<Notify>,
+        finish_notify: Arc<Notify>,
     ) -> Result<Arc<Mutex<Self>>, APIError> {
         let cache_engine = CacheEngine::new(
             pipeline.get_model_config(),
@@ -76,6 +78,7 @@
             cache_engine,
             sliding_window,
             notify: notify.clone(),
+            finish_notify: finish_notify.clone(),
             completion_records: HashMap::new(),
         }));
         let engine_clone = engine.clone();
@@ -133,6 +136,7 @@
                     );
                     e.completion_records
                         .insert(request_id.clone(), (choices, usage));
+                    finish_notify.notify_one();
                 }
             });
         });
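On the engine side, `finish_notify.notify_one()` runs right after the completion record is inserted. A useful property of `tokio::sync::Notify` here: if `notify_one()` fires while no task is currently waiting, it stores a single permit, so a handler that only starts waiting afterwards still returns immediately instead of hanging. A tiny illustration of that permit behavior:

```rust
use tokio::sync::Notify;

#[tokio::main]
async fn main() {
    let notify = Notify::new();

    // Signal while nobody is waiting: Notify stores one permit.
    notify.notify_one();

    // A waiter that arrives later consumes the stored permit and
    // completes immediately instead of blocking forever.
    notify.notified().await;
    println!("woke up via the stored permit");
}
```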
tests/tests.rs — 3 changes: 3 additions & 0 deletions
@@ -34,6 +34,7 @@ async fn test_llama() -> Result<(), APIError> {
         None,
     )?;
     let model = loader.load_model(paths, DType::F16, Device::Cpu)?;
+    let finish_notify = Arc::new(Notify::new());
     let llm_engine = LLMEngine::new(
         model.0,
         SchedulerConfig { max_num_seqs: 256 },
@@ -45,13 +46,15 @@
             dtype: DType::F16,
         },
         Arc::new(Notify::new()),
+        finish_notify.clone(),
     )?;

     let server_data = OpenAIServerData {
         pipeline_config: model.1,
         model: llm_engine,
         device: Device::Cpu,
         record_conversation: false,
+        finish_notify: finish_notify.clone(),
     };

     let allow_origin = AllowOrigin::any();