Skip to content
This repository has been archived by the owner on Feb 14, 2025. It is now read-only.

forward headers #3

Merged
merged 2 commits into from
Jul 22, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 15 additions & 5 deletions src/forward_service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use std::{

use axum::{
extract::{Path, State},
http::HeaderMap,
response::IntoResponse,
routing::post,
Router,
Expand Down Expand Up @@ -102,11 +103,12 @@ fn router(shared_state: SharedState) -> Router {
async fn scan_id_forward_request(
State(state): State<Arc<SharedState>>,
Path(chain_id): Path<u16>,
headers: HeaderMap,
body: Bytes,
) -> Result<impl IntoResponse, impl IntoResponse> {
if let Some(manager) = state.managers.get(&chain_id) {
if let Some(entry) = manager.get_next_elected_preconfer() {
match inner_forward_request(body, &entry.url, &state.client).await {
match inner_forward_request(&state.client, &entry.url, body, headers).await {
Ok(res) => Ok(res),
Err(_) => Err((
StatusCode::INTERNAL_SERVER_ERROR,
Expand All @@ -129,11 +131,12 @@ async fn forward_request(State(_state): State<Arc<SharedState>>) -> impl IntoRes
}

async fn inner_forward_request(
bytes: Bytes,
to_addr: &str,
client: &ClientWithMiddleware,
to_addr: &str,
bytes: Bytes,
headers: HeaderMap,
) -> Result<impl IntoResponse> {
let res = client.post(to_addr).body(bytes).send().await?;
let res = client.post(to_addr).body(bytes).headers(headers).send().await?;
let body = res.bytes().await?;
Ok(body)
}
Expand All @@ -147,6 +150,7 @@ mod test {

use axum::{
extract::State,
http::HeaderMap,
response::IntoResponse,
routing::{get, post},
Router,
Expand All @@ -155,7 +159,7 @@ mod test {
use dashmap::DashMap;
use eyre::Result;
use hashbrown::HashMap;
use http::StatusCode;
use http::{HeaderValue, StatusCode};

use crate::{
forward_service::{router, SharedState},
Expand Down Expand Up @@ -253,9 +257,13 @@ mod test {
});
tokio::time::sleep(Duration::from_secs(1)).await;
for _ in 0..10 {
let mut headers = HeaderMap::new();
headers.insert("Content-Type", HeaderValue::from_str("application/json").unwrap());
let res = reqwest::Client::new()
.post("http://localhost:12005/1")
.body("dummy plain body")
.headers(headers)
.headers(HeaderMap::new())
.send()
.await
.unwrap();
Expand All @@ -270,9 +278,11 @@ mod test {

async fn handle_request(
State(state): State<Arc<Mutex<DummySharedState>>>,
headers: HeaderMap,
body: Bytes,
) -> impl IntoResponse {
assert_eq!("dummy plain body", String::from_utf8(body.into()).unwrap());
assert_eq!(headers.get("Content-Type").unwrap(), "application/json");
{
let mut s = state.lock().unwrap();
s.cnt += 1;
Expand Down
Loading