diff --git a/Cargo.lock b/Cargo.lock
index 6b5013b28..753c6b925 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -37,6 +37,21 @@ dependencies = [
  "memchr",
 ]
 
+[[package]]
+name = "android-tzdata"
+version = "0.1.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "e999941b234f3131b00bc13c22d06e8c5ff726d1b6318ac7eb276997bbb4fef0"
+
+[[package]]
+name = "android_system_properties"
+version = "0.1.5"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "819e7219dbd41043ac279b19830f2efc897156490d7fd6ea916720117ee66311"
+dependencies = [
+ "libc",
+]
+
 [[package]]
 name = "anstream"
 version = "0.6.12"
@@ -155,7 +170,7 @@ checksum = "771051cdc7eec2dc1b23fbf870bb7fbb89136fe374227c875e377f1eed99a429"
 dependencies = [
  "futures",
  "generational-arena",
- "parking_lot",
+ "parking_lot 0.12.1",
  "slotmap",
 ]
 
@@ -302,6 +317,18 @@ version = "1.0.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd"
 
+[[package]]
+name = "chrono"
+version = "0.4.35"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "8eaf5903dcbc0a39312feb77df2ff4c76387d591b9fc7b04a238dcf8bb62639a"
+dependencies = [
+ "android-tzdata",
+ "iana-time-zone",
+ "num-traits",
+ "windows-targets 0.52.3",
+]
+
 [[package]]
 name = "clap"
 version = "4.5.1"
@@ -503,7 +530,7 @@ dependencies = [
  "hashbrown 0.14.0",
  "lock_api",
  "once_cell",
- "parking_lot_core",
+ "parking_lot_core 0.9.8",
 ]
 
 [[package]]
@@ -1029,6 +1056,29 @@ dependencies = [
  "tokio-native-tls",
 ]
 
+[[package]]
+name = "iana-time-zone"
+version = "0.1.60"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "e7ffbb5a1b541ea2561f8c41c087286cc091e21e556a4f09a8f6cbf17b69b141"
+dependencies = [
+ "android_system_properties",
+ "core-foundation-sys",
+ "iana-time-zone-haiku",
+ "js-sys",
+ "wasm-bindgen",
+ "windows-core",
+]
+
+[[package]]
+name = "iana-time-zone-haiku"
+version = "0.1.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "f31827a206f56af32e590ba56d5d2d085f558508192593743f16b2306495269f"
+dependencies = [
+ "cc",
+]
+
 [[package]]
 name = "ident_case"
 version = "1.0.1"
@@ -1086,6 +1136,9 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "7a5bbe824c507c5da5956355e86a746d82e0e1464f65d862cc5e71da70e94b2c"
 dependencies = [
  "cfg-if",
+ "js-sys",
+ "wasm-bindgen",
+ "web-sys",
 ]
 
 [[package]]
@@ -1234,6 +1287,8 @@ dependencies = [
  "opentelemetry-otlp",
  "rand",
  "reqwest",
+ "reqwest-middleware",
+ "reqwest-retry",
  "serde",
  "serde_json",
  "thiserror",
@@ -1506,7 +1561,7 @@ dependencies = [
  "hyper",
  "muxado",
  "once_cell",
- "parking_lot",
+ "parking_lot 0.12.1",
  "regex",
  "rustls-pemfile",
  "serde",
@@ -1568,6 +1623,15 @@ dependencies = [
  "winapi",
 ]
 
+[[package]]
+name = "num-traits"
+version = "0.2.18"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "da0df0e5185db44f69b44f26786fe401b6c293d1907744beaa7fa62b2e5a517a"
+dependencies = [
+ "autocfg",
+]
+
 [[package]]
 name = "num_cpus"
 version = "1.16.0"
@@ -1813,6 +1877,17 @@ version = "0.1.1"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "b15813163c1d831bf4a13c3610c05c0d03b39feb07f7e09fa234dac9b15aaf39"
 
+[[package]]
+name = "parking_lot"
+version = "0.11.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "7d17b78036a60663b797adeaee46f5c9dfebb86948d1255007a1d6be0271ff99"
+dependencies = [
+ "instant",
+ "lock_api",
+ "parking_lot_core 0.8.6",
+]
+
 [[package]]
 name = "parking_lot"
 version = "0.12.1"
@@ -1820,7 +1895,21 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "3742b2c103b9f06bc9fff0a37ff4912935851bee6d36f3c02bcc755bcfec228f"
 dependencies = [
  "lock_api",
- "parking_lot_core",
+ "parking_lot_core 0.9.8",
+]
+
+[[package]]
+name = "parking_lot_core"
+version = "0.8.6"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "60a2cfe6f0ad2bfc16aefa463b497d5c7a5ecd44a23efa72aa342d90177356dc"
+dependencies = [
+ "cfg-if",
+ "instant",
+ "libc",
+ "redox_syscall 0.2.16",
+ "smallvec",
+ "winapi",
 ]
 
 [[package]]
@@ -2199,6 +2288,7 @@ dependencies = [
  "js-sys",
  "log",
  "mime",
+ "mime_guess",
  "native-tls",
  "once_cell",
  "percent-encoding",
@@ -2216,6 +2306,55 @@ dependencies = [
  "winreg",
 ]
 
+[[package]]
+name = "reqwest-middleware"
+version = "0.2.4"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "88a3e86aa6053e59030e7ce2d2a3b258dd08fc2d337d52f73f6cb480f5858690"
+dependencies = [
+ "anyhow",
+ "async-trait",
+ "http",
+ "reqwest",
+ "serde",
+ "task-local-extensions",
+ "thiserror",
+]
+
+[[package]]
+name = "reqwest-retry"
+version = "0.4.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "cadced6a67c5c2d1c819cc2d7e6ddf066f32b9b6a04f8866203ceeb44b79c37f"
+dependencies = [
+ "anyhow",
+ "async-trait",
+ "chrono",
+ "futures",
+ "getrandom",
+ "http",
+ "hyper",
+ "parking_lot 0.11.2",
+ "reqwest",
+ "reqwest-middleware",
+ "retry-policies",
+ "task-local-extensions",
+ "tokio",
+ "tracing",
+ "wasm-timer",
+]
+
+[[package]]
+name = "retry-policies"
+version = "0.3.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "493b4243e32d6eedd29f9a398896e35c6943a123b55eec97dcaee98310d25810"
+dependencies = [
+ "anyhow",
+ "chrono",
+ "rand",
+]
+
 [[package]]
 name = "ring"
 version = "0.16.20"
@@ -2662,6 +2801,15 @@ dependencies = [
  "winapi",
 ]
 
+[[package]]
+name = "task-local-extensions"
+version = "0.1.4"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "ba323866e5d033818e3240feeb9f7db2c4296674e4d9e16b97b7bf8f490434e8"
+dependencies = [
+ "pin-utils",
+]
+
 [[package]]
 name = "tempfile"
 version = "3.6.0"
@@ -2796,7 +2944,7 @@ dependencies = [
  "libc",
  "mio",
  "num_cpus",
- "parking_lot",
+ "parking_lot 0.12.1",
  "pin-project-lite",
  "signal-hook-registry",
  "socket2",
@@ -3415,6 +3563,21 @@ version = "0.2.87"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "ca6ad05a4870b2bf5fe995117d3728437bd27d7cd5f06f13c17443ef369775a1"
 
+[[package]]
+name = "wasm-timer"
+version = "0.2.5"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "be0ecb0db480561e9a7642b5d3e4187c128914e58aa84330b9493e3eb68c5e7f"
+dependencies = [
+ "futures",
+ "js-sys",
+ "parking_lot 0.11.2",
+ "pin-utils",
+ "wasm-bindgen",
+ "wasm-bindgen-futures",
+ "web-sys",
+]
+
 [[package]]
 name = "web-sys"
 version = "0.3.64"
@@ -3486,6 +3649,15 @@ version = "0.4.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f"
 
+[[package]]
+name = "windows-core"
+version = "0.52.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "33ab640c8d7e35bf8ba19b884ba838ceb4fba93a4e8c65a9059d08afcfc683d9"
+dependencies = [
+ "windows-targets 0.52.3",
+]
+
 [[package]]
 name = "windows-sys"
 version = "0.45.0"
diff --git a/router/Cargo.toml b/router/Cargo.toml
index 079582ada..ecc9eae64 100644
--- a/router/Cargo.toml
+++ b/router/Cargo.toml
@@ -30,6 +30,8 @@ opentelemetry = { version = "0.19.0", features = ["rt-tokio"] }
 opentelemetry-otlp = "0.12.0"
 rand = "0.8.5"
 reqwest = { version = "0.11.14", features = [] }
+reqwest-middleware = "0.2.4"
+reqwest-retry = "0.4.0"
 serde = "1.0.152"
 serde_json = { version = "1.0.93", features = ["preserve_order"] }
 thiserror = "1.0.38"
diff --git a/router/src/lib.rs b/router/src/lib.rs
index c3ca10f87..61f627f77 100644
--- a/router/src/lib.rs
+++ b/router/src/lib.rs
@@ -66,6 +66,8 @@ pub struct Info {
     pub sha: Option<&'static str>,
     #[schema(nullable = true, example = "null")]
     pub docker_label: Option<&'static str>,
+    #[schema(nullable = true, example = "http://localhost:8899")]
+    pub request_logger_url: Option<String>,
 }
 
 #[derive(Clone, Debug, Deserialize, ToSchema, Default)]
diff --git a/router/src/server.rs b/router/src/server.rs
index 87f6b70f6..efc4ce3b1 100644
--- a/router/src/server.rs
+++ b/router/src/server.rs
@@ -2,6 +2,7 @@
 use crate::adapter::{extract_adapter_params, BASE_MODEL_ADAPTER_ID};
 use crate::health::Health;
 use crate::infer::{InferError, InferResponse, InferStreamResponse};
+use crate::json;
 use crate::validation::ValidationError;
 use crate::{
     BestOfSequence, ChatCompletionRequest, ChatCompletionResponse, ChatCompletionStreamResponse,
@@ -10,7 +11,7 @@ use crate::{
     HubModelInfo, Infer, Info, PrefillToken, StreamDetails, StreamResponse, Token, Validation,
 };
 use axum::extract::Extension;
-use axum::http::{HeaderMap, Method, StatusCode};
+use axum::http::{request, HeaderMap, Method, StatusCode};
 use axum::response::sse::{Event, KeepAlive, Sse};
 use axum::response::{IntoResponse, Response};
 use axum::routing::{get, post};
@@ -21,21 +22,24 @@ use futures::Stream;
 use lorax_client::{ShardInfo, ShardedClient};
 use metrics_exporter_prometheus::{Matcher, PrometheusBuilder, PrometheusHandle};
 use once_cell::sync::OnceCell;
+use reqwest_middleware::{ClientBuilder, ClientWithMiddleware};
+use reqwest_retry::{policies::ExponentialBackoff, RetryTransientMiddleware};
 use std::convert::Infallible;
 use std::net::SocketAddr;
 use std::sync::atomic::AtomicBool;
 use std::sync::Arc;
 use tokenizers::Tokenizer;
 use tokio::signal;
+use tokio::sync::mpsc;
 use tokio::time::Instant;
 use tower_http::cors::{
     AllowCredentials, AllowHeaders, AllowMethods, AllowOrigin, CorsLayer, ExposeHeaders,
 };
 use tracing::{info_span, instrument, Instrument};
+use utoipa::openapi::info;
 use utoipa::OpenApi;
 use utoipa_swagger_ui::SwaggerUi;
 
-static MODEL_ID: OnceCell<String> = OnceCell::new();
 pub static DEFAULT_ADAPTER_SOURCE: OnceCell<String> = OnceCell::new();
 
 /// Generate tokens if `stream == false` or a stream of token if `stream == true`
@@ -64,6 +68,8 @@ example = json ! ({"error": "Incomplete generation"})),
 async fn compat_generate(
     default_return_full_text: Extension<bool>,
     infer: Extension<Infer>,
+    info: Extension<Info>,
+    request_logger_sender: Extension<Arc<mpsc::Sender<(i64, String, String)>>>,
     req_headers: HeaderMap,
     req: Json<CompatGenerateRequest>,
 ) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
@@ -76,11 +82,24 @@ async fn compat_generate(
 
     // switch on stream
     if req.stream {
-        Ok(generate_stream(infer, req_headers, Json(req.into()))
-            .await
-            .into_response())
+        Ok(generate_stream(
+            infer,
+            info,
+            request_logger_sender,
+            req_headers,
+            Json(req.into()),
+        )
+        .await
+        .into_response())
     } else {
-        let (headers, generation) = generate(infer, req_headers, Json(req.into())).await?;
+        let (headers, generation) = generate(
+            infer,
+            info,
+            request_logger_sender,
+            req_headers,
+            Json(req.into()),
+        )
+        .await?;
         // wrap generation inside a Vec to match api-inference
         Ok((headers, Json(vec![generation.0])).into_response())
     }
@@ -112,11 +131,13 @@ example = json ! ({"error": "Incomplete generation"})),
 async fn completions_v1(
     default_return_full_text: Extension<bool>,
     infer: Extension<Infer>,
+    info: Extension<Info>,
+    request_logger_sender: Extension<Arc<mpsc::Sender<(i64, String, String)>>>,
     req_headers: HeaderMap,
     req: Json<CompletionRequest>,
 ) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
     let mut req = req.0;
-    if req.model == MODEL_ID.get().unwrap().as_str() {
+    if req.model == info.model_id.as_str() {
         // Allow user to specify the base model, but treat it as an empty adapter_id
         tracing::info!("Replacing base model {0} with empty adapter_id", req.model);
         req.model = "".to_string();
@@ -142,11 +163,25 @@ async fn completions_v1(
             )
         };
 
-        let (headers, stream) =
-            generate_stream_with_callback(infer, req_headers, Json(gen_req.into()), callback).await;
+        let (headers, stream) = generate_stream_with_callback(
+            infer,
+            info,
+            request_logger_sender,
+            req_headers,
+            Json(gen_req.into()),
+            callback,
+        )
+        .await;
         Ok((headers, Sse::new(stream).keep_alive(KeepAlive::default())).into_response())
     } else {
-        let (headers, generation) = generate(infer, req_headers, Json(gen_req.into())).await?;
+        let (headers, generation) = generate(
+            infer,
+            info,
+            request_logger_sender,
+            req_headers,
+            Json(gen_req.into()),
+        )
+        .await?;
         // wrap generation inside a Vec to match api-inference
         Ok((headers, Json(CompletionResponse::from(generation.0))).into_response())
     }
@@ -178,11 +213,13 @@ example = json ! ({"error": "Incomplete generation"})),
({"error": "Incomplete generation"})), async fn chat_completions_v1( default_return_full_text: Extension, infer: Extension, + info: Extension, + request_logger_sender: Extension>>, req_headers: HeaderMap, req: Json, ) -> Result)> { let mut req = req.0; - if req.model == MODEL_ID.get().unwrap().as_str() { + if req.model == info.model_id.as_str() { // Allow user to specify the base model, but treat it as an empty adapter_id tracing::info!("Replacing base model {0} with empty adapter_id", req.model); req.model = "".to_string(); @@ -208,11 +245,25 @@ async fn chat_completions_v1( ) }; - let (headers, stream) = - generate_stream_with_callback(infer, req_headers, Json(gen_req.into()), callback).await; + let (headers, stream) = generate_stream_with_callback( + infer, + info, + request_logger_sender, + req_headers, + Json(gen_req.into()), + callback, + ) + .await; Ok((headers, Sse::new(stream).keep_alive(KeepAlive::default())).into_response()) } else { - let (headers, generation) = generate(infer, req_headers, Json(gen_req.into())).await?; + let (headers, generation) = generate( + infer, + info, + request_logger_sender, + req_headers, + Json(gen_req.into()), + ) + .await?; // wrap generation inside a Vec to match api-inference Ok((headers, Json(ChatCompletionResponse::from(generation.0))).into_response()) } @@ -287,6 +338,8 @@ seed, )] async fn generate( infer: Extension, + info: Extension, + request_logger_sender: Extension>>, req_headers: HeaderMap, mut req: Json, ) -> Result<(HeaderMap, Json), (StatusCode, Json)> { @@ -320,6 +373,8 @@ async fn generate( }); } + let api_token = req.parameters.api_token.clone(); + // Inference let (response, best_of_responses) = match req.0.parameters.best_of { Some(best_of) if best_of > 1 => { @@ -430,7 +485,7 @@ async fn generate( time_per_token.as_millis().to_string().parse().unwrap(), ); - headers.insert("x-model-id", MODEL_ID.get().unwrap().parse().unwrap()); + headers.insert("x-model-id", info.model_id.parse().unwrap()); let adapter_id_string = adapter_parameters .adapter_ids @@ -467,6 +522,16 @@ async fn generate( response.generated_text.generated_tokens as f64 ); + if info.request_logger_url.is_some() { + let _ = request_logger_sender + .send(( + response.generated_text.generated_tokens as i64, + api_token.unwrap_or("".to_string()), + info.model_id.clone(), + )) + .await; + } + // Send response let mut output_text = response.generated_text.text; if let Some(prompt) = add_prompt { @@ -520,6 +585,8 @@ seed, )] async fn generate_stream( infer: Extension, + info: Extension, + request_logger_sender: Extension>>, req_headers: HeaderMap, mut req: Json, ) -> ( @@ -527,12 +594,22 @@ async fn generate_stream( Sse>>, ) { let callback = |resp: StreamResponse| Event::default().json_data(resp).unwrap(); - let (headers, stream) = generate_stream_with_callback(infer, req_headers, req, callback).await; + let (headers, stream) = generate_stream_with_callback( + infer, + info, + request_logger_sender, + req_headers, + req, + callback, + ) + .await; (headers, Sse::new(stream).keep_alive(KeepAlive::default())) } async fn generate_stream_with_callback( infer: Extension, + info: Extension, + request_logger_sender: Extension>>, req_headers: HeaderMap, mut req: Json, callback: impl Fn(StreamResponse) -> Event, @@ -564,6 +641,8 @@ async fn generate_stream_with_callback( }); } + let api_token = req.parameters.api_token.clone(); + let (adapter_source, adapter_parameters) = extract_adapter_params( req.0.parameters.adapter_id.clone(), req.0.parameters.adapter_source.clone(), @@ -584,7 
+663,7 @@ async fn generate_stream_with_callback( headers.insert("x-adapter-source", adapter_source.unwrap().parse().unwrap()); } - headers.insert("x-model-id", MODEL_ID.get().unwrap().parse().unwrap()); + headers.insert("x-model-id", info.model_id.parse().unwrap()); let stream = async_stream::stream! { // Inference @@ -681,6 +760,8 @@ async fn generate_stream_with_callback( metrics::histogram!("lorax_request_mean_time_per_token_duration", time_per_token.as_secs_f64()); metrics::histogram!("lorax_request_generated_tokens", generated_text.generated_tokens as f64); + + // StreamResponse end_reached = true; @@ -692,6 +773,10 @@ async fn generate_stream_with_callback( tracing::debug!(parent: &span, "Output: {}", output_text); tracing::info!(parent: &span, "Success"); + if info.request_logger_url.is_some() { + let _ = request_logger_sender.send((generated_text.generated_tokens as i64, api_token.unwrap_or("".to_string()), info.model_id.clone())).await; + } + let stream_token = StreamResponse { token, generated_text: Some(output_text), @@ -743,6 +828,36 @@ async fn metrics(prom_handle: Extension) -> String { prom_handle.render() } +async fn request_logger( + request_logger_url: Option, + mut rx: mpsc::Receiver<(i64, String, String)>, +) { + if request_logger_url.is_none() { + tracing::info!("REQUEST_LOGGER_URL not set, request logging is disabled"); + return; + } + + let url_string = request_logger_url.unwrap(); + tracing::info!("Request logging enabled, sending logs to {url_string}"); + + let retry_policy = ExponentialBackoff::builder().build_with_max_retries(3); + let client = ClientBuilder::new(reqwest::Client::new()) + .with(RetryTransientMiddleware::new_with_policy(retry_policy)) + .build(); + while let Some((tokens, api_token, model_id)) = rx.recv().await { + // Make a request out to localhost:8899 with the tokens, api_token, and model_id + let res = client + .post(&url_string) + .json(&json!({"tokens": tokens, "api_token": api_token, "model_id": model_id})) + .send() + .await; + + if let Err(e) = res { + tracing::error!("Failed to log request: {e}"); + } + } +} + /// Serving method #[allow(clippy::too_many_arguments)] pub async fn run( @@ -841,8 +956,6 @@ pub async fn run( generation_health, ); - let model_id = model_info.model_id.clone(); - // Duration buckets let duration_matcher = Matcher::Suffix(String::from("duration")); let n_duration_buckets = 35; @@ -935,17 +1048,24 @@ pub async fn run( version: env!("CARGO_PKG_VERSION"), sha: option_env!("VERGEN_GIT_SHA"), docker_label: option_env!("DOCKER_LABEL"), + request_logger_url: std::env::var("REQUEST_LOGGER_URL").ok(), }; - MODEL_ID.set(model_id.clone()).unwrap_or_else(|_| { - panic!("MODEL_ID was already set!"); - }); DEFAULT_ADAPTER_SOURCE .set(adapter_source.clone()) .unwrap_or_else(|_| { panic!("DEFAULT_ADAPTER_SOURCE was already set!"); }); + // Kick off thread here that writes to the log file + let (tx, rx) = mpsc::channel(32); + let request_logger_sender = Arc::new(tx); + if info.request_logger_url.is_some() { + tokio::spawn(request_logger(info.request_logger_url.clone(), rx)); + } else { + tracing::info!("REQUEST_LOGGER_URL not set, request logging is disabled"); + } + // Create router let app = Router::new() .merge(SwaggerUi::new("/docs").url("/api-doc/openapi.json", ApiDoc::openapi())) @@ -967,6 +1087,7 @@ pub async fn run( // Prometheus metrics route .route("/metrics", get(metrics)) .layer(Extension(info)) + .layer(Extension(request_logger_sender.clone())) .layer(Extension(health_ext.clone())) 
.layer(Extension(compat_return_full_text)) .layer(Extension(infer))