Switch streaming service to axum & standalone generation thread
guoqingbao committed Jul 24, 2024
1 parent f939b4e commit 0d509e9
Showing 14 changed files with 916 additions and 975 deletions.
748 changes: 288 additions & 460 deletions Cargo.lock

Large diffs are not rendered by default.

7 changes: 6 additions & 1 deletion Cargo.toml
@@ -6,8 +6,13 @@ edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
actix-web = "4.8.0"
axum = { version = "0.7.4", features = ["tokio"] }
utoipa = { version = "4.2", features = ["axum_extras"] }
tower-http = { version = "0.5.1", features = ["cors"]}
flume = "0.10.14"
#actix-web = "4.8.0"
anyhow = "1.0.75"
hyper = { version = "0.14", features = ["full"] }
candle-core = { git = "https://github.com/huggingface/candle.git", version = "0.6.0" }
candle-examples = { git = "https://github.com/huggingface/candle.git", version = "0.6.0" }
#candle-lora = { git = "https://github.com/EricLBuehler/candle-lora.git", version = "0.2.0" }
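
The new `flume` dependency fits the "standalone generation thread" half of the commit title: a channel lets blocking model code run on its own thread while async handlers consume tokens without stalling the runtime. A minimal sketch of that hand-off, using a hypothetical token type rather than the repository's actual streaming types:

// Sketch only: `Token` and the word list are placeholders, not candle-vllm types.
struct Token(String);

#[tokio::main]
async fn main() {
    // Unbounded MPMC channel from flume: the generation thread sends,
    // the async server side receives.
    let (tx, rx) = flume::unbounded::<Token>();

    // Standalone generation thread: keeps blocking model work off the async runtime.
    std::thread::spawn(move || {
        for word in ["hello", "from", "the", "generation", "thread"] {
            if tx.send(Token(word.to_string())).is_err() {
                break; // every receiver was dropped, stop generating
            }
        }
    });

    // Async side: `recv_async` yields to the executor instead of blocking a worker.
    while let Ok(Token(word)) = rx.recv_async().await {
        println!("streamed: {word}");
    }
}
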
4 changes: 2 additions & 2 deletions src/lib.rs
@@ -209,10 +209,10 @@ impl ToString for ModelSelected {
}
}

pub fn get_model_loader<'a>(
pub fn get_model_loader(
selected_model: ModelSelected,
model_id: Option<String>,
) -> (Box<dyn ModelLoader<'a>>, String) {
) -> (Box<dyn ModelLoader>, String) {
match selected_model {
ModelSelected::Llama {
repeat_last_n,
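
Without the `'a` parameter, `Box<dyn ModelLoader>` is shorthand for `Box<dyn ModelLoader + 'static>`, i.e. a trait object that owns its data and can be moved into a spawned thread or task, which is what a standalone generation thread needs. A minimal sketch of the idea, with a hypothetical trait in place of the real `ModelLoader`:

trait Loader {
    fn name(&self) -> String;
}

struct DummyLoader;

impl Loader for DummyLoader {
    fn name(&self) -> String {
        "dummy".to_string()
    }
}

// `dyn Loader` with no explicit lifetime defaults to `dyn Loader + 'static`;
// adding `Send` lets the boxed loader move to another thread.
fn get_loader() -> Box<dyn Loader + Send> {
    Box::new(DummyLoader)
}

fn main() {
    let loader = get_loader();
    // Compiles because the trait object carries no borrowed lifetime.
    std::thread::spawn(move || println!("loading model via {}", loader.name()))
        .join()
        .unwrap();
}
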
53 changes: 24 additions & 29 deletions src/main.rs
@@ -1,6 +1,8 @@
use actix_web::middleware::Logger;
use actix_web::web::Data;
use actix_web::{App, HttpServer};
use axum::{
http::{self, Method},
routing::post,
Router,
};
use candle_core::{DType, Device};
use candle_examples;
use candle_vllm::openai::openai_server::chat_completions;
@@ -12,12 +14,12 @@ use candle_vllm::scheduler::cache_engine::CacheConfig;
use candle_vllm::scheduler::SchedulerConfig;
use candle_vllm::{get_model_loader, hub_load_local_safetensors, ModelSelected};
use clap::Parser;
use futures::lock::Mutex;
use std::sync::Arc;
const SIZE_IN_MB: usize = 1024 * 1024;
use candle_vllm::openai::models::Config;
use std::path::Path;

use tokio::sync::Notify;
use tower_http::cors::{AllowOrigin, CorsLayer};
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
@@ -77,7 +79,7 @@ struct Args {
record_conversation: bool,
}

#[actix_web::main]
#[tokio::main]
async fn main() -> Result<(), APIError> {
let args = Args::parse();
let (loader, model_id) = get_model_loader(args.command, args.model_id.clone());
@@ -142,42 +144,35 @@ async fn main() -> Result<(), APIError> {
max_num_seqs: args.max_num_seqs,
},
cache_config,
Arc::new(Notify::new()),
)?;

let server_data = OpenAIServerData {
pipeline_config: model.1,
model: Arc::new(Mutex::new(llm_engine)),
model: llm_engine,
record_conversation: args.record_conversation,
device: Device::Cpu,
};

println!("Server started at http://127.0.0.1:{}.", args.port);
if args.verbose {
env_logger::init_from_env(env_logger::Env::new().default_filter_or("info"));

HttpServer::new(move || {
App::new()
.wrap(Logger::default())
.service(chat_completions)
.app_data(Data::new(server_data.clone()))
})
.bind(("127.0.0.1", args.port))
.map_err(|e| APIError::new(e.to_string()))?
.run()

let allow_origin = AllowOrigin::any();
let cors_layer = CorsLayer::new()
.allow_methods([Method::GET, Method::POST])
.allow_headers([http::header::CONTENT_TYPE])
.allow_origin(allow_origin);

let app = Router::new()
.layer(cors_layer)
.route("/v1/chat/completions", post(chat_completions))
.with_state(Arc::new(server_data));

let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", args.port))
.await
.map_err(|e| APIError::new(e.to_string()))?;
} else {
HttpServer::new(move || {
App::new()
.service(chat_completions)
.app_data(Data::new(server_data.clone()))
})
.bind(("127.0.0.1", args.port))
.map_err(|e| APIError::new(e.to_string()))?
.run()
axum::serve(listener, app)
.await
.map_err(|e| APIError::new(e.to_string()))?;
}

Ok(())
}
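
On the axum side, the route is registered with `post(chat_completions)` and the shared `OpenAIServerData` is attached with `.with_state(Arc::new(server_data))`, so the handler presumably receives it through axum's `State` extractor. The real handler lives in `openai_server.rs` and is not shown in this diff; the following is only a sketch of that wiring with stand-in types (`ServerData`, a raw JSON body via `serde_json`):

use std::sync::Arc;

use axum::{extract::State, routing::post, Json, Router};
use serde_json::{json, Value};

// Stand-in for OpenAIServerData; the real struct holds the engine,
// pipeline config, and device.
struct ServerData {
    model_name: String,
}

// Hypothetical handler shape: axum injects the shared state through the
// `State` extractor and deserializes the request body as JSON.
async fn chat_completions(
    State(data): State<Arc<ServerData>>,
    Json(request): Json<Value>,
) -> Json<Value> {
    Json(json!({ "model": data.model_name, "echo": request }))
}

#[tokio::main]
async fn main() {
    let data = Arc::new(ServerData {
        model_name: "demo".to_string(),
    });

    let app = Router::new()
        .route("/v1/chat/completions", post(chat_completions))
        .with_state(data);

    // Arbitrary port for the sketch; the real server binds to `args.port`.
    let listener = tokio::net::TcpListener::bind("127.0.0.1:2000")
        .await
        .unwrap();
    axum::serve(listener, app).await.unwrap();
}
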
7 changes: 3 additions & 4 deletions src/openai/mod.rs
@@ -1,7 +1,7 @@
use candle_core::Device;
use futures::lock::Mutex;
use std::sync::Arc;
use tokenizers::{EncodeInput, Encoding, Tokenizer};
use tokio::sync::Mutex;

use self::{pipelines::llm_engine::LLMEngine, responses::APIError};

@@ -40,9 +40,8 @@ pub struct PipelineConfig {
pub temperature: f32,
}

#[derive(Clone)]
pub struct OpenAIServerData<'s> {
pub model: Arc<Mutex<LLMEngine<'s>>>,
pub struct OpenAIServerData {
pub model: Arc<Mutex<LLMEngine>>,
pub pipeline_config: PipelineConfig,
pub record_conversation: bool,
pub device: Device,
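
Switching to `tokio::sync::Mutex` keeps the engine lock usable across `.await` points inside handlers, and the `Arc<Notify>` handed to the engine in `main.rs` is the usual tokio primitive for waking a background loop. A minimal sketch of that shared-state pattern, with a toy engine standing in for `LLMEngine` (how the real engine uses the notifier is not shown in this diff):

use std::sync::Arc;
use tokio::sync::{Mutex, Notify};

// Toy stand-in for LLMEngine: just counts queued requests.
struct Engine {
    pending: usize,
}

#[tokio::main]
async fn main() {
    let engine = Arc::new(Mutex::new(Engine { pending: 0 }));
    let notify = Arc::new(Notify::new());

    // Background generation loop: parks until the server signals new work.
    let engine_bg = Arc::clone(&engine);
    let notify_bg = Arc::clone(&notify);
    tokio::spawn(async move {
        loop {
            notify_bg.notified().await;
            let mut engine = engine_bg.lock().await;
            println!("generating for {} pending request(s)", engine.pending);
            engine.pending = 0;
        }
    });

    // Server side: queue a request, drop the lock, then wake the loop.
    {
        let mut engine = engine.lock().await;
        engine.pending += 1;
    }
    notify.notify_one();

    // Give the background task a chance to run in this toy example.
    tokio::time::sleep(std::time::Duration::from_millis(50)).await;
}
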
