diff --git a/src/lib.rs b/src/lib.rs index 2ea0aa31b..f24d5211f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,5 +1,6 @@ pub mod config; pub mod logger; +pub mod middleware; pub mod routes; pub mod server; pub mod state; diff --git a/src/main.rs b/src/main.rs index 4a5104783..c463d9312 100644 --- a/src/main.rs +++ b/src/main.rs @@ -2,6 +2,7 @@ use log::info; mod config; mod logger; +mod middleware; mod routes; mod server; mod state; diff --git a/src/middleware.rs b/src/middleware.rs new file mode 100644 index 000000000..bb57931af --- /dev/null +++ b/src/middleware.rs @@ -0,0 +1,77 @@ +use axum::{ + body::Body, + extract::Request, + http::StatusCode, + middleware::Next, + response::{IntoResponse, Response}, +}; +use tracing::{enabled, trace, Level}; +use core::str; +use http_body_util::BodyExt; +use serde_json::Value; + +pub async fn body_logger_middleware( + request: Request, + next: Next, +) -> Result { + if !enabled!(target: "live_compositor::log_request_body", Level::TRACE) { + return Ok(next.run(request).await); + } + let request = buffer_request_body(request).await?; + let response = next.run(request).await; + let response = buffer_response_body(response).await?; + + Ok(response) +} + +async fn buffer_request_body(request: Request) -> Result { + let (parts, body) = request.into_parts(); + + let bytes = body + .collect() + .await + .map_err(|err| (StatusCode::INTERNAL_SERVER_ERROR, err.to_string()).into_response())? + .to_bytes(); + + match serde_json::from_slice::(&bytes) { + Ok(body_json) => { + trace!(target: "live_compositor::log_request_body", method = ?parts.method, path = ?parts.uri, "Request body: {}", body_json); + } + Err(_) => match str::from_utf8(&bytes) { + Ok(body_str) => { + trace!(target: "live_compositor::log_request_body", method = ?parts.method, path = ?parts.uri, "Request body: {}", body_str); + } + Err(_) => { + trace!(target: "live_compositor::log_request_body", method = ?parts.method, path = ?parts.uri, "Request body: {:?}", bytes); + } + } + } + + Ok(Request::from_parts(parts, Body::from(bytes))) +} + +async fn buffer_response_body(response: Response) -> Result { + let (parts, body) = response.into_parts(); + + let bytes = body + .collect() + .await + .map_err(|err| (StatusCode::INTERNAL_SERVER_ERROR, err.to_string()).into_response())? + .to_bytes(); + + match serde_json::from_slice::(&bytes) { + Ok(body_json) => { + trace!(target: "live_compositor::log_request_body", status=?parts.status, "Response body: {}", body_json); + } + Err(_) => match str::from_utf8(&bytes) { + Ok(body_str) => { + trace!(target: "live_compositor::log_request_body", status=?parts.status, "Response body: {}", body_str); + } + Err(_) => { + trace!(target: "live_compositor::log_request_body", status=?parts.status, "Response body: {:?}", bytes); + } + }, + } + + Ok(Response::from_parts(parts, Body::from(bytes))) +} diff --git a/src/routes.rs b/src/routes.rs index 500bb8a90..6a41f3568 100644 --- a/src/routes.rs +++ b/src/routes.rs @@ -1,20 +1,16 @@ use axum::{ async_trait, - body::Body, extract::{rejection::JsonRejection, ws::WebSocketUpgrade, FromRequest, Request, State}, http::StatusCode, - middleware::{self, Next}, - response::{IntoResponse, Response}, + middleware, + response::IntoResponse, routing::{get, post}, Router, }; use compositor_pipeline::Pipeline; -use core::str; -use http_body_util::BodyExt; -use log::info; use serde_json::{json, Value}; -use crate::state::{self, ApiState}; +use crate::state::{ApiState, Response}; use compositor_api::error::ApiError; @@ -22,6 +18,7 @@ use self::{ update_output::handle_keyframe_request, update_output::handle_output_update, ws::handle_ws_upgrade, }; +use crate::middleware::body_logger_middleware; mod register_request; mod unregister_request; @@ -59,9 +56,9 @@ pub fn routes(state: ApiState) -> Router { .route("/:id/register", post(register_request::handle_shader)) .route("/:id/unregister", post(unregister_request::handle_shader)); - async fn handle_start(State(state): State) -> Result { + async fn handle_start(State(state): State) -> Result { Pipeline::start(&state.pipeline); - Ok(state::Response::Ok {}) + Ok(Response::Ok {}) } Router::new() @@ -80,7 +77,7 @@ pub fn routes(state: ApiState) -> Router { "instance_id": state.config.instance_id }))), ) - .layer(middleware::from_fn(log_request_response_body)) + .layer(middleware::from_fn(body_logger_middleware)) .with_state(state) } @@ -117,42 +114,3 @@ where } } } - -async fn log_request_response_body( - request: Request, - next: Next, -) -> Result { - let request = buffer_request_body(request).await?; - let response = next.run(request).await; - let response = buffer_response_body(response).await?; - - Ok(response) -} - -async fn buffer_request_body(request: Request) -> Result { - let (parts, body) = request.into_parts(); - - let bytes = body - .collect() - .await - .map_err(|err| (StatusCode::INTERNAL_SERVER_ERROR, err.to_string()).into_response())? - .to_bytes(); - - info!("Request body: {:?}", str::from_utf8(&bytes).unwrap()); - - Ok(Request::from_parts(parts, Body::from(bytes))) -} - -async fn buffer_response_body(response: Response) -> Result { - let (parts, body) = response.into_parts(); - - let bytes = body - .collect() - .await - .map_err(|err| (StatusCode::INTERNAL_SERVER_ERROR, err.to_string()).into_response())? - .to_bytes(); - - info!("Response body: {:?}", str::from_utf8(&bytes).unwrap()); - - Ok(Response::from_parts(parts, Body::from(bytes))) -}