diff --git a/lading/src/blackhole/http.rs b/lading/src/blackhole/http.rs index 51aec797c..6549c2be7 100644 --- a/lading/src/blackhole/http.rs +++ b/lading/src/blackhole/http.rs @@ -16,7 +16,6 @@ use hyper::{ Body, Request, Response, Server, StatusCode, }; use metrics::{register_counter, Counter}; -use once_cell::unsync::OnceCell; use serde::{Deserialize, Serialize}; use tower::ServiceBuilder; use tracing::{debug, error, info}; @@ -25,9 +24,6 @@ use crate::signals::Shutdown; use super::General; -#[allow(clippy::declare_interior_mutable_const)] -const RESPONSE: OnceCell> = OnceCell::new(); - fn default_concurrent_requests_max() -> usize { 100 } @@ -54,6 +50,8 @@ pub enum BodyVariant { Nothing, /// All response bodies will mimic AWS Kinesis. AwsKinesis, + /// Respond with a hardcoded byte slice value + RawBytes, /// Respond with a hardcoded string value Static(String), } @@ -89,6 +87,9 @@ pub struct Config { /// the content-type header to respond with, defaults to 200 #[serde(default = "default_status_code")] pub status: u16, + /// raw array of bytes if the raw_bytes body variant is selected + #[serde(default)] + pub raw_bytes: Vec, } #[derive(Serialize)] @@ -112,7 +113,7 @@ async fn srv( status: StatusCode, bytes_received: Counter, requests_received: Counter, - body_variant: BodyVariant, + body_bytes: Vec, req: Request, headers: HeaderMap, ) -> Result, hyper::Error> { @@ -129,27 +130,7 @@ async fn srv( let mut okay = Response::default(); *okay.status_mut() = status; - *okay.headers_mut() = headers; - - let body_bytes = RESPONSE - .get_or_init(|| match body_variant { - BodyVariant::AwsKinesis => { - let response = KinesisPutRecordBatchResponse { - encrypted: None, - failed_put_count: 0, - request_responses: vec![KinesisPutRecordBatchResponseEntry { - error_code: None, - error_message: None, - record_id: "foobar".to_string(), - }], - }; - serde_json::to_vec(&response).unwrap() - } - BodyVariant::Nothing => vec![], - BodyVariant::Static(val) => val.as_bytes().to_vec(), - }) - .clone(); *okay.body_mut() = Body::from(body_bytes); Ok(okay) } @@ -160,7 +141,7 @@ async fn srv( /// The HTTP blackhole. pub struct Http { httpd_addr: SocketAddr, - body_variant: BodyVariant, + body_bytes: Vec, concurrency_limit: usize, shutdown: Shutdown, headers: HeaderMap, @@ -174,6 +155,10 @@ impl Http { /// # Errors /// /// Returns an error if the configuration is invalid. + /// + /// # Panics + /// + /// None known. pub fn new(general: General, config: &Config, shutdown: Shutdown) -> Result { let status = StatusCode::from_u16(config.status).map_err(Error::InvalidStatusCode)?; @@ -185,9 +170,27 @@ impl Http { metric_labels.push(("id".to_string(), id)); } + let body_bytes = match &config.body_variant { + BodyVariant::AwsKinesis => { + let response = KinesisPutRecordBatchResponse { + encrypted: None, + failed_put_count: 0, + request_responses: vec![KinesisPutRecordBatchResponseEntry { + error_code: None, + error_message: None, + record_id: "foobar".to_string(), + }], + }; + serde_json::to_vec(&response).unwrap() + } + BodyVariant::Nothing => vec![], + BodyVariant::RawBytes => config.raw_bytes.clone(), + BodyVariant::Static(val) => val.as_bytes().to_vec(), + }; + Ok(Self { httpd_addr: config.binding_addr, - body_variant: config.body_variant.clone(), + body_bytes, concurrency_limit: config.concurrent_requests_max, headers: config.headers.clone(), status, @@ -205,17 +208,13 @@ impl Http { /// /// Function will return an error if the configuration is invalid or if /// receiving a packet fails. - /// - /// # Panics - /// - /// None known. pub async fn run(mut self) -> Result<(), Error> { let bytes_received = register_counter!("bytes_received", &self.metric_labels); let requests_received = register_counter!("requests_received", &self.metric_labels); let service = make_service_fn(|_: &AddrStream| { let bytes_received = bytes_received.clone(); let requests_received = requests_received.clone(); - let body_variant = self.body_variant.clone(); + let body_bytes = self.body_bytes.clone(); let headers = self.headers.clone(); async move { Ok::<_, hyper::Error>(service_fn(move |request| { @@ -224,7 +223,7 @@ impl Http { self.status, bytes_received.clone(), requests_received.clone(), - body_variant.clone(), + body_bytes.clone(), request, headers.clone(), ) @@ -259,3 +258,51 @@ impl Http { } } } + +#[cfg(test)] +mod tests { + use std::str::FromStr; + + use super::*; + + #[test] + fn config_deserializes_variant_nothing() { + let contents = r#" +binding_addr: "127.0.0.1:1000" +body_variant: "nothing" +"#; + let config: Config = serde_yaml::from_str(contents).unwrap(); + assert_eq!( + config, + Config { + concurrent_requests_max: default_concurrent_requests_max(), + binding_addr: SocketAddr::from_str("127.0.0.1:1000").unwrap(), + body_variant: BodyVariant::Nothing, + headers: default_headers(), + status: default_status_code(), + raw_bytes: vec![], + }, + ); + } + + #[test] + fn config_deserializes_raw_bytes() { + let contents = r#" +binding_addr: "127.0.0.1:1000" +body_variant: "raw_bytes" +raw_bytes: [0x01, 0x02, 0x10] +"#; + let config: Config = serde_yaml::from_str(contents).unwrap(); + assert_eq!( + config, + Config { + concurrent_requests_max: default_concurrent_requests_max(), + binding_addr: SocketAddr::from_str("127.0.0.1:1000").unwrap(), + body_variant: BodyVariant::RawBytes, + headers: default_headers(), + status: default_status_code(), + raw_bytes: vec![0x01, 0x02, 0x10], + }, + ); + } +}