Skip to content

Commit

Permalink
fix: Add Custom Header option not replacing already existing headers #…
Browse files Browse the repository at this point in the history
  • Loading branch information
rholshausen committed May 23, 2023
1 parent 8156751 commit 71b38a8
Show file tree
Hide file tree
Showing 3 changed files with 103 additions and 18 deletions.
2 changes: 1 addition & 1 deletion rust/pact_verifier/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ exclude = [
anyhow = "1.0.66"
serde = "1.0.147"
serde_json = "1.0.87"
pact_matching = { version = "~1.1.1", path = "../pact_matching" }
pact_matching = { version = "~1.1.0", path = "../pact_matching" }
pact_models = "~1.1.2"
pact-plugin-driver = "~0.4.4"
maplit = "1.0.2"
Expand Down
2 changes: 1 addition & 1 deletion rust/pact_verifier/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -641,7 +641,7 @@ async fn execute_provider_states<S: ProviderStateExecutor>(
}

/// Configure the HTTP client to use for requests to the provider
fn configure_http_client<F: RequestFilterExecutor>(
pub(crate) fn configure_http_client<F: RequestFilterExecutor>(
options: &VerificationOptions<F>
) -> anyhow::Result<Client> {
let mut client_builder = reqwest::Client::builder()
Expand Down
117 changes: 101 additions & 16 deletions rust/pact_verifier/src/provider_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,12 @@ pub fn join_paths(base: &str, path: &str) -> String {
}
}

fn create_native_request(client: &Client, base_url: &str, request: &HttpRequest) -> Result<RequestBuilder, ProviderClientError> {
fn create_native_request(
client: &Client,
base_url: &str,
request: &HttpRequest,
custom_headers: &HashMap<String, String>
) -> Result<RequestBuilder, ProviderClientError> {
let url = join_paths(base_url, &request.path.clone());
let mut builder = client.request(Method::from_bytes(
&request.method.clone().into_bytes()).unwrap_or(Method::GET), &url);
Expand All @@ -101,13 +106,15 @@ fn create_native_request(client: &Client, base_url: &str, request: &HttpRequest)
if let Some(headers) = &request.headers {
let mut header_map = HeaderMap::new();
for (k, vals) in headers {
for header_value in vals {
let header_name = HeaderName::try_from(k)
.map_err(|err| ProviderClientError::RequestHeaderNameError(
format!("Failed to parse header value: {}", header_value), err))?;
header_map.append(header_name, HeaderValue::from_str(header_value.as_str())
.map_err(|err| ProviderClientError::RequestHeaderValueError(
format!("Failed to parse header value: {}", header_value), err))?);
if !custom_headers.contains_key(k) {
for header_value in vals {
let header_name = HeaderName::try_from(k)
.map_err(|err| ProviderClientError::RequestHeaderNameError(
format!("Failed to parse header value: {}", header_value), err))?;
header_map.append(header_name, HeaderValue::from_str(header_value.as_str())
.map_err(|err| ProviderClientError::RequestHeaderValueError(
format!("Failed to parse header value: {}", header_value), err))?);
}
}
}
builder = builder.headers(header_map);
Expand Down Expand Up @@ -211,7 +218,7 @@ pub async fn make_provider_request<F: RequestFilterExecutor>(
debug!("Provider details = {provider:?}");
info!("Sending request {request}");
debug!("body:\n{}", request.body.display_string());
let request = create_native_request(client, &base_url, &request)?;
let request = create_native_request(client, &base_url, &request, &options.custom_headers)?;

let response = request.send()
.map_err(|err| anyhow!(err))
Expand All @@ -234,7 +241,7 @@ pub async fn make_state_change_request(
) -> anyhow::Result<HashMap<String, Value>> {
debug!("Sending {} to state change handler", request);

let request = create_native_request(client, state_change_url, request)?;
let request = create_native_request(client, state_change_url, request, &hashmap!{})?;
let result = with_retries(retries, request).await;

match result {
Expand Down Expand Up @@ -282,6 +289,16 @@ mod tests {
use pact_models::bodies::OptionalBody;
use pact_models::v4::http_parts::HttpRequest;

use pact_consumer::builders::{HttpPartBuilder, PactBuilderAsync};
use pact_consumer::mock_server::StartMockServer;

use crate::{
configure_http_client,
NullRequestFilterExecutor,
ProviderInfo,
VerificationOptions
};

use super::{create_native_request, extract_headers, join_paths};

#[test]
Expand Down Expand Up @@ -336,7 +353,7 @@ mod tests {
let client = reqwest::Client::new();
let base_url = "http://example.test:8080".to_string();
let request = HttpRequest::default();
let request_builder = create_native_request(&client, &base_url, &request).unwrap().build().unwrap();
let request_builder = create_native_request(&client, &base_url, &request, &hashmap!{}).unwrap().build().unwrap();

expect!(request_builder.method()).to(be_equal_to("GET"));
expect!(request_builder.url().as_str()).to(be_equal_to("http://example.test:8080/"));
Expand All @@ -354,7 +371,7 @@ mod tests {
}),
.. HttpRequest::default()
};
let request_builder = create_native_request(&client, &base_url, &request).unwrap().build().unwrap();
let request_builder = create_native_request(&client, &base_url, &request, &hashmap!{}).unwrap().build().unwrap();

expect!(request_builder.method()).to(be_equal_to("GET"));
expect!(request_builder.url().as_str()).to(be_equal_to("http://example.test:8080/?a=b&c=d&c=e"));
Expand All @@ -371,7 +388,7 @@ mod tests {
}),
.. HttpRequest::default()
};
let request_builder = create_native_request(&client, &base_url, &request).unwrap().build().unwrap();
let request_builder = create_native_request(&client, &base_url, &request, &hashmap!{}).unwrap().build().unwrap();

expect!(request_builder.method()).to(be_equal_to("GET"));
expect!(request_builder.url().as_str()).to(be_equal_to("http://example.test:8080/"));
Expand All @@ -392,7 +409,7 @@ mod tests {
body: OptionalBody::from("body"),
.. HttpRequest::default()
};
let request_builder = create_native_request(&client, &base_url, &request).unwrap().build().unwrap();
let request_builder = create_native_request(&client, &base_url, &request, &hashmap!{}).unwrap().build().unwrap();

expect!(request_builder.method()).to(be_equal_to("GET"));
expect!(request_builder.url().as_str()).to(be_equal_to("http://example.test:8080/"));
Expand All @@ -407,7 +424,7 @@ mod tests {
body: OptionalBody::Null,
.. HttpRequest::default()
};
let request_builder = create_native_request(&client, &base_url, &request).unwrap().build().unwrap();
let request_builder = create_native_request(&client, &base_url, &request, &hashmap!{}).unwrap().build().unwrap();

expect!(request_builder.method()).to(be_equal_to("GET"));
expect!(request_builder.url().as_str()).to(be_equal_to("http://example.test:8080/"));
Expand All @@ -425,10 +442,78 @@ mod tests {
body: OptionalBody::Null,
.. HttpRequest::default()
};
let request_builder = create_native_request(&client, &base_url, &request).unwrap().build().unwrap();
let request_builder = create_native_request(&client, &base_url, &request, &hashmap!{}).unwrap().build().unwrap();

expect!(request_builder.method()).to(be_equal_to("GET"));
expect!(request_builder.url().as_str()).to(be_equal_to("http://example.test:8080/"));
expect!(request_builder.body().unwrap().as_bytes()).to(be_some().value("null".as_bytes()));
}

#[tokio::test]
async fn do_not_overwrite_custom_headers_with_headers_from_pact_file() {
let request = HttpRequest {
headers: Some(hashmap!{
"X-A".to_string() => vec![ "val-a".to_string() ],
"X-B".to_string() => vec![ "val-b".to_string() ],
"X-C".to_string() => vec![ "val-c".to_string() ]
}),
.. HttpRequest::default()
};
let options = VerificationOptions {
custom_headers: hashmap!{
"X-B".to_string() => "other-b".to_string()
},
.. VerificationOptions::<NullRequestFilterExecutor>::default()
};
let client = configure_http_client(&options).unwrap();

let server = PactBuilderAsync::new("make_provider_request", "provider")
.interaction("request with headers to be overwritten", "", |mut i| async move {
i.request
.method("GET")
.header("X-A", "val-a")
.header("X-B", "other-b")
.header("X-C", "val-c");
i.response.ok();
i
})
.await
.start_mock_server(None);

#[allow(deprecated)]
let provider = ProviderInfo {
port: server.url().port(),
.. ProviderInfo::default()
};
super::make_provider_request(&provider, &request, &options, &client, None).await.unwrap();
}

#[test]
fn convert_request_to_native_request_with_custom_headers() {
let client = reqwest::Client::new();
let base_url = "http://example.test:8080".to_string();
let request = HttpRequest {
headers: Some(hashmap! {
"X-A".to_string() => vec![ "val-a".to_string() ],
"X-B".to_string() => vec![ "val-b".to_string() ],
"X-C".to_string() => vec![ "val-c".to_string() ]
}),
.. HttpRequest::default()
};
let custom_headers = hashmap!{
"X-B".to_string() => "other-b".to_string(),
"X-D".to_string() => "val-d".to_string()
};
let request_builder = create_native_request(&client, &base_url, &request, &custom_headers).unwrap().build().unwrap();

let headers = request_builder.headers();
let keys = headers.keys()
.map(|k| k.as_str())
.sorted()
.collect_vec();
expect!(keys).to(be_equal_to(vec![
"x-a",
"x-c"
]));
}
}

0 comments on commit 71b38a8

Please sign in to comment.