Skip to content

Commit

Permalink
collector: Set Content-Length when polling collection jobs (#2999)
Browse files Browse the repository at this point in the history
  • Loading branch information
inahga authored Apr 10, 2024
1 parent d423372 commit 6af1dea
Showing 1 changed file with 68 additions and 2 deletions.
70 changes: 68 additions & 2 deletions collector/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ use prio::{
};
use rand::random;
use reqwest::{
header::{HeaderValue, ToStrError, CONTENT_TYPE, RETRY_AFTER},
header::{HeaderValue, ToStrError, CONTENT_LENGTH, CONTENT_TYPE, RETRY_AFTER},
StatusCode,
};
pub use retry_after;
Expand Down Expand Up @@ -528,6 +528,9 @@ impl<V: vdaf::Collector> Collector<V> {
let (auth_header, auth_value) = self.authentication.request_authentication();
self.http_client
.post(collection_job_url.clone())
// reqwest does not send Content-Length for requests with empty bodies. Some
// HTTP servers require this anyway, so explicitly set it.
.header(CONTENT_LENGTH, 0)
.header(auth_header, auth_value)
.send()
.await
Expand Down Expand Up @@ -771,7 +774,7 @@ mod tests {
};
use rand::random;
use reqwest::{
header::{AUTHORIZATION, CONTENT_TYPE},
header::{AUTHORIZATION, CONTENT_LENGTH, CONTENT_TYPE},
StatusCode, Url,
};
use retry_after::RetryAfter;
Expand Down Expand Up @@ -1942,4 +1945,67 @@ mod tests {

mock_error.assert_async().await;
}

#[tokio::test]
async fn poll_content_length_header() {
install_test_trace_subscriber();
let mut server = mockito::Server::new_async().await;
let vdaf = Prio3::new_count(2).unwrap();
let transcript = run_vdaf(&vdaf, &random(), &(), &random(), &0);
let collector = setup_collector(&mut server, vdaf);
let (auth_header, auth_value) = collector.authentication.request_authentication();

let batch_interval = Interval::new(
Time::from_seconds_since_epoch(1_000_000),
Duration::from_seconds(3600),
)
.unwrap();
let collect_resp =
build_collect_response_time(&transcript, &collector, &(), batch_interval);

let job = CollectionJob {
collection_job_id: random(),
query: Query::new_time_interval(batch_interval),
aggregation_parameter: (),
};

let collection_job_path = format!(
"/tasks/{}/collection_jobs/{}",
collector.task_id, job.collection_job_id
);
let mocked_collect_error = server
.mock("POST", collection_job_path.as_str())
.with_status(500)
.expect(1)
.create_async()
.await;
let mocked_collect_accepted = server
.mock("POST", collection_job_path.as_str())
.match_header(CONTENT_LENGTH.as_str(), "0")
.with_status(202)
.expect(2)
.create_async()
.await;
let mocked_collect_complete = server
.mock("POST", collection_job_path.as_str())
.match_header(auth_header, auth_value.as_str())
.match_header(CONTENT_LENGTH.as_str(), "0")
.with_status(200)
.with_header(
CONTENT_TYPE.as_str(),
CollectionMessage::<TimeInterval>::MEDIA_TYPE,
)
.with_body(collect_resp.get_encoded())
.expect(1)
.create_async()
.await;

let poll_result = collector.poll_once(&job).await.unwrap();
assert_matches!(poll_result, PollResult::NotReady(None));

collector.poll_until_complete(&job).await.unwrap();
mocked_collect_error.assert_async().await;
mocked_collect_accepted.assert_async().await;
mocked_collect_complete.assert_async().await;
}
}

0 comments on commit 6af1dea

Please sign in to comment.