From 795abff4a99de31a140949d301b00e454c6eda7c Mon Sep 17 00:00:00 2001 From: fbrv Date: Mon, 22 Jul 2024 08:59:54 +0100 Subject: [PATCH] forward headers --- src/forward_service.rs | 20 +++++++++++++++----- 1 file changed, 15 insertions(+), 5 deletions(-) diff --git a/src/forward_service.rs b/src/forward_service.rs index 21b99b6..92e254a 100644 --- a/src/forward_service.rs +++ b/src/forward_service.rs @@ -5,6 +5,7 @@ use std::{ use axum::{ extract::{Path, State}, + http::HeaderMap, response::IntoResponse, routing::post, Router, @@ -102,11 +103,12 @@ fn router(shared_state: SharedState) -> Router { async fn scan_id_forward_request( State(state): State>, Path(chain_id): Path, + headers: HeaderMap, body: Bytes, ) -> Result { if let Some(manager) = state.managers.get(&chain_id) { if let Some(entry) = manager.get_next_elected_preconfer() { - match inner_forward_request(body, &entry.url, &state.client).await { + match inner_forward_request(&state.client, &entry.url, body, headers).await { Ok(res) => Ok(res), Err(_) => Err(( StatusCode::INTERNAL_SERVER_ERROR, @@ -129,11 +131,12 @@ async fn forward_request(State(_state): State>) -> impl IntoRes } async fn inner_forward_request( - bytes: Bytes, - to_addr: &str, client: &ClientWithMiddleware, + to_addr: &str, + bytes: Bytes, + headers: HeaderMap, ) -> Result { - let res = client.post(to_addr).body(bytes).send().await?; + let res = client.post(to_addr).body(bytes).headers(headers).send().await?; let body = res.bytes().await?; Ok(body) } @@ -147,6 +150,7 @@ mod test { use axum::{ extract::State, + http::HeaderMap, response::IntoResponse, routing::{get, post}, Router, @@ -155,7 +159,7 @@ mod test { use dashmap::DashMap; use eyre::Result; use hashbrown::HashMap; - use http::StatusCode; + use http::{HeaderValue, StatusCode}; use crate::{ forward_service::{router, SharedState}, @@ -253,9 +257,13 @@ mod test { }); tokio::time::sleep(Duration::from_secs(1)).await; for _ in 0..10 { + let mut headers = HeaderMap::new(); + headers.insert("Content-Type", HeaderValue::from_str("application/json").unwrap()); let res = reqwest::Client::new() .post("http://localhost:12005/1") .body("dummy plain body") + .headers(headers) + .headers(HeaderMap::new()) .send() .await .unwrap(); @@ -270,9 +278,11 @@ mod test { async fn handle_request( State(state): State>>, + headers: HeaderMap, body: Bytes, ) -> impl IntoResponse { assert_eq!("dummy plain body", String::from_utf8(body.into()).unwrap()); + assert_eq!(headers.get("Content-Type").unwrap(), "application/json"); { let mut s = state.lock().unwrap(); s.cnt += 1;