Skip to content

Commit

Permalink
Store collection job ID instead of URL (#1797)
Browse files Browse the repository at this point in the history
  • Loading branch information
divergentdave authored Aug 23, 2023
1 parent 32cd883 commit 451df1f
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 65 deletions.
114 changes: 61 additions & 53 deletions collector/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -241,10 +241,8 @@ struct CollectionJob<P, Q>
where
Q: QueryType,
{
/// The URL provided by the leader aggregator, where the collect response will be available
/// upon completion.
#[derivative(Debug(format_with = "std::fmt::Display::fmt"))]
collection_job_url: Url,
/// The collection job ID.
collection_job_id: CollectionJobId,
/// The collect request's query.
query: Query<Q>,
/// The aggregation parameter used in this collect request.
Expand All @@ -254,12 +252,12 @@ where

impl<P, Q: QueryType> CollectionJob<P, Q> {
fn new(
collection_job_url: Url,
collection_job_id: CollectionJobId,
query: Query<Q>,
aggregation_parameter: P,
) -> CollectionJob<P, Q> {
CollectionJob {
collection_job_url,
collection_job_id,
query,
aggregation_parameter,
}
Expand Down Expand Up @@ -395,7 +393,8 @@ impl<V: vdaf::Collector> Collector<V> {
) -> Result<CollectionJob<V::AggregationParam, Q>, Error> {
let collect_request =
CollectionReq::new(query.clone(), aggregation_parameter.get_encoded());
let collection_job_url = self.parameters.collection_job_uri(random())?;
let collection_job_id = random();
let collection_job_url = self.parameters.collection_job_uri(collection_job_id)?;

let response_res = retry_http_request(
self.parameters.http_request_retry_parameters.clone(),
Expand Down Expand Up @@ -434,7 +433,7 @@ impl<V: vdaf::Collector> Collector<V> {
};

Ok(CollectionJob::new(
collection_job_url,
collection_job_id,
query,
aggregation_parameter.clone(),
))
Expand All @@ -447,13 +446,14 @@ impl<V: vdaf::Collector> Collector<V> {
&self,
job: &CollectionJob<V::AggregationParam, Q>,
) -> Result<PollResult<V::AggregateResult, Q>, Error> {
let collection_job_url = self.parameters.collection_job_uri(job.collection_job_id)?;
let response_res = retry_http_request(
self.parameters.http_request_retry_parameters.clone(),
|| async {
let (auth_header, auth_value) =
self.parameters.authentication.request_authentication();
self.http_client
.post(job.collection_job_url.clone())
.post(collection_job_url.clone())
.header(auth_header, auth_value)
.send()
.await
Expand Down Expand Up @@ -634,29 +634,6 @@ impl<V: vdaf::Collector> Collector<V> {
}
}

#[cfg(feature = "test-util")]
#[cfg_attr(docsrs, doc(cfg(feature = "test-util")))]
pub mod test_util {
use crate::{Collection, Collector, Error};
use janus_messages::{query_type::QueryType, Query};
use prio::vdaf;

pub async fn collect_with_rewritten_url<V: vdaf::Collector, Q: QueryType>(
collector: &Collector<V>,
query: Query<Q>,
aggregation_parameter: &V::AggregationParam,
host: &str,
port: u16,
) -> Result<Collection<V::AggregateResult, Q>, Error> {
let mut job = collector
.start_collection(query, aggregation_parameter)
.await?;
job.collection_job_url.set_host(Some(host))?;
job.collection_job_url.set_port(Some(port)).unwrap();
collector.poll_until_complete(&job).await
}
}

#[cfg(test)]
mod tests {
use crate::{
Expand Down Expand Up @@ -880,20 +857,24 @@ mod tests {
let job = job.unwrap();
assert_eq!(job.query.batch_interval(), &batch_interval);

let collection_job_path = format!(
"/tasks/{}/collection_jobs/{}",
collector.parameters.task_id, job.collection_job_id
);
let mocked_collect_error = server
.mock("POST", job.collection_job_url.path())
.mock("POST", collection_job_path.as_str())
.with_status(500)
.expect(1)
.create_async()
.await;
let mocked_collect_accepted = server
.mock("POST", job.collection_job_url.path())
.mock("POST", collection_job_path.as_str())
.with_status(202)
.expect(2)
.create_async()
.await;
let mocked_collect_complete = server
.mock("POST", job.collection_job_url.path())
.mock("POST", collection_job_path.as_str())
.match_header(auth_header, auth_value.as_str())
.with_status(200)
.with_header(
Expand Down Expand Up @@ -965,8 +946,12 @@ mod tests {
assert_eq!(job.query.batch_interval(), &batch_interval);
mocked_collect_start_success.assert_async().await;

let collection_job_path = format!(
"/tasks/{}/collection_jobs/{}",
collector.parameters.task_id, job.collection_job_id
);
let mocked_collect_complete = server
.mock("POST", job.collection_job_url.path())
.mock("POST", collection_job_path.as_str())
.with_status(200)
.with_header(
CONTENT_TYPE.as_str(),
Expand Down Expand Up @@ -1033,8 +1018,12 @@ mod tests {

mocked_collect_start_success.assert_async().await;

let collection_job_path = format!(
"/tasks/{}/collection_jobs/{}",
collector.parameters.task_id, job.collection_job_id
);
let mocked_collect_complete = server
.mock("POST", job.collection_job_url.path())
.mock("POST", collection_job_path.as_str())
.with_status(200)
.with_header(
CONTENT_TYPE.as_str(),
Expand Down Expand Up @@ -1110,8 +1099,12 @@ mod tests {

mocked_collect_start_success.assert_async().await;

let collection_job_path = format!(
"/tasks/{}/collection_jobs/{}",
collector.parameters.task_id, job.collection_job_id
);
let mocked_collect_complete = server
.mock("POST", job.collection_job_url.path())
.mock("POST", collection_job_path.as_str())
.with_status(200)
.with_header(
CONTENT_TYPE.as_str(),
Expand Down Expand Up @@ -1180,8 +1173,12 @@ mod tests {

mocked_collect_start_success.assert_async().await;

let collection_job_path = format!(
"/tasks/{}/collection_jobs/{}",
collector.parameters.task_id, job.collection_job_id
);
let mocked_collect_complete = server
.mock("POST", job.collection_job_url.path())
.mock("POST", collection_job_path.as_str())
.with_status(200)
.with_header(
CONTENT_TYPE.as_str(),
Expand Down Expand Up @@ -1260,8 +1257,12 @@ mod tests {
let job = job.unwrap();
assert_eq!(job.query.batch_interval(), &batch_interval);

let collection_job_path = format!(
"/tasks/{}/collection_jobs/{}",
collector.parameters.task_id, job.collection_job_id
);
let mocked_collect_complete = server
.mock("POST", job.collection_job_url.path())
.mock("POST", collection_job_path.as_str())
.match_header(AUTHORIZATION.as_str(), "Bearer AAAAAAAAAAAAAAAA")
.with_status(200)
.with_header(
Expand Down Expand Up @@ -1427,8 +1428,12 @@ mod tests {
mock_collect_start.assert_async().await;
mock_collection_job_server_error.assert_async().await;

let collection_job_path = format!(
"/tasks/{}/collection_jobs/{}",
collector.parameters.task_id, job.collection_job_id
);
let mock_collection_job_server_error_details = server
.mock("POST", job.collection_job_url.path())
.mock("POST", collection_job_path.as_str())
.with_status(500)
.with_header("Content-Type", "application/problem+json")
.with_body("{\"type\": \"http://example.com/test_server_error\"}")
Expand All @@ -1448,7 +1453,7 @@ mod tests {
.await;

let mock_collection_job_bad_request = server
.mock("POST", job.collection_job_url.path())
.mock("POST", collection_job_path.as_str())
.with_status(400)
.with_header("Content-Type", "application/problem+json")
.with_body(concat!(
Expand All @@ -1471,7 +1476,7 @@ mod tests {
mock_collection_job_bad_request.assert_async().await;

let mock_collection_job_bad_message_bytes = server
.mock("POST", job.collection_job_url.path())
.mock("POST", collection_job_path.as_str())
.with_status(200)
.with_header(
CONTENT_TYPE.as_str(),
Expand All @@ -1488,7 +1493,7 @@ mod tests {
mock_collection_job_bad_message_bytes.assert_async().await;

let mock_collection_job_bad_share_count = server
.mock("POST", job.collection_job_url.path())
.mock("POST", collection_job_path.as_str())
.with_status(200)
.with_header(
CONTENT_TYPE.as_str(),
Expand All @@ -1513,7 +1518,7 @@ mod tests {
mock_collection_job_bad_share_count.assert_async().await;

let mock_collection_job_bad_ciphertext = server
.mock("POST", job.collection_job_url.path())
.mock("POST", collection_job_path.as_str())
.with_status(200)
.with_header(
CONTENT_TYPE.as_str(),
Expand Down Expand Up @@ -1582,7 +1587,7 @@ mod tests {
]),
);
let mock_collection_job_bad_shares = server
.mock("POST", job.collection_job_url.path())
.mock("POST", collection_job_path.as_str())
.with_status(200)
.with_header(
CONTENT_TYPE.as_str(),
Expand Down Expand Up @@ -1633,7 +1638,7 @@ mod tests {
]),
);
let mock_collection_job_wrong_length = server
.mock("POST", job.collection_job_url.path())
.mock("POST", collection_job_path.as_str())
.with_status(200)
.with_header(
CONTENT_TYPE.as_str(),
Expand All @@ -1650,7 +1655,7 @@ mod tests {
mock_collection_job_wrong_length.assert_async().await;

let mock_collection_job_always_fail = server
.mock("POST", job.collection_job_url.path())
.mock("POST", collection_job_path.as_str())
.with_status(500)
.expect_at_least(3)
.create_async()
Expand Down Expand Up @@ -1692,8 +1697,12 @@ mod tests {
.unwrap();
mock_collect_start.assert_async().await;

let collection_job_path = format!(
"/tasks/{}/collection_jobs/{}",
collector.parameters.task_id, job.collection_job_id
);
let mock_collect_poll_no_retry_after = server
.mock("POST", job.collection_job_url.path())
.mock("POST", collection_job_path.as_str())
.with_status(202)
.expect(1)
.create_async()
Expand All @@ -1705,7 +1714,7 @@ mod tests {
mock_collect_poll_no_retry_after.assert_async().await;

let mock_collect_poll_retry_after_60s = server
.mock("POST", job.collection_job_url.path())
.mock("POST", collection_job_path.as_str())
.with_status(202)
.with_header("Retry-After", "60")
.expect(1)
Expand All @@ -1718,7 +1727,7 @@ mod tests {
mock_collect_poll_retry_after_60s.assert_async().await;

let mock_collect_poll_retry_after_date_time = server
.mock("POST", job.collection_job_url.path())
.mock("POST", collection_job_path.as_str())
.with_status(202)
.with_header("Retry-After", "Wed, 21 Oct 2015 07:28:00 GMT")
.expect(1)
Expand Down Expand Up @@ -1752,14 +1761,13 @@ mod tests {
collector.parameters.task_id
);

let collection_job_url = format!("{}{collection_job_path}", server.url());
let batch_interval = Interval::new(
Time::from_seconds_since_epoch(1_000_000),
Duration::from_seconds(3600),
)
.unwrap();
let job = CollectionJob::new(
collection_job_url.parse().unwrap(),
collection_job_id,
Query::new_time_interval(batch_interval),
(),
);
Expand Down
14 changes: 2 additions & 12 deletions integration_tests/tests/common/mod.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
use backoff::{future::retry, ExponentialBackoffBuilder};
use itertools::Itertools;
use janus_aggregator_core::task::{test_util::TaskBuilder, QueryType};
use janus_collector::{
test_util::collect_with_rewritten_url, Collection, Collector, CollectorParameters,
};
use janus_collector::{Collection, Collector, CollectorParameters};
use janus_core::{
hpke::test_util::generate_test_hpke_config_and_private_key,
retries::test_http_request_exponential_backoff,
Expand Down Expand Up @@ -76,8 +74,6 @@ pub async fn collect_generic<'a, V, Q>(
collector: &Collector<V>,
query: Query<Q>,
aggregation_parameter: &V::AggregationParam,
host: &str,
port: u16,
) -> Result<Collection<V::AggregateResult, Q>, janus_collector::Error>
where
V: vdaf::Client<16> + vdaf::Collector + InteropClientEncoding,
Expand All @@ -94,9 +90,7 @@ where
retry(backoff, || {
let query = query.clone();
async move {
match collect_with_rewritten_url(collector, query, aggregation_parameter, host, port)
.await
{
match collector.collect(query, aggregation_parameter).await {
Ok(collection) => Ok(collection),
Err(
error @ janus_collector::Error::Http {
Expand Down Expand Up @@ -168,8 +162,6 @@ pub async fn submit_measurements_and_verify_aggregate_generic<V>(
&collector,
Query::new_time_interval(batch_interval),
&test_case.aggregation_parameter,
"127.0.0.1",
leader_port,
)
.await
.unwrap();
Expand All @@ -188,8 +180,6 @@ pub async fn submit_measurements_and_verify_aggregate_generic<V>(
&collector,
Query::new_fixed_size(FixedSizeQuery::CurrentBatch),
&test_case.aggregation_parameter,
"127.0.0.1",
leader_port,
)
.await;
match collection_res {
Expand Down

0 comments on commit 451df1f

Please sign in to comment.