diff --git a/Cargo.lock b/Cargo.lock index 6ccb13a16..9e30916ce 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3022,6 +3022,7 @@ dependencies = [ "base64 0.22.1", "chrono", "clap", + "derivative", "divviup-client", "futures", "hex", @@ -3036,8 +3037,11 @@ dependencies = [ "janus_messages 0.7.28", "k8s-openapi", "kube", + "opentelemetry", "prio", + "quickcheck", "rand", + "regex", "reqwest", "rstest", "serde", @@ -3045,6 +3049,9 @@ dependencies = [ "tempfile", "testcontainers", "tokio", + "tracing", + "trillium", + "trillium-macros 0.0.6", "trillium-rustls", "trillium-tokio", "url", diff --git a/aggregator/src/aggregator/aggregation_job_creator.rs b/aggregator/src/aggregator/aggregation_job_creator.rs index c7d96a985..2a655214f 100644 --- a/aggregator/src/aggregator/aggregation_job_creator.rs +++ b/aggregator/src/aggregator/aggregation_job_creator.rs @@ -68,7 +68,7 @@ use trillium_tokio::{CloneCounterObserver, Stopper}; pub struct AggregationJobCreator<C: Clock> { // Dependencies. - datastore: Datastore<C>, + datastore: Arc<Datastore<C>>, meter: Meter, // Configuration values. @@ -92,7 +92,7 @@ pub struct AggregationJobCreator<C: Clock> { impl<C: Clock + 'static> AggregationJobCreator<C> { pub fn new( - datastore: Datastore<C>, + datastore: Arc<Datastore<C>>, meter: Meter, batch_aggregation_shard_count: u64, tasks_update_frequency: Duration, @@ -298,7 +298,7 @@ impl<C: Clock + 'static> AggregationJobCreator<C> { fields(task_id = ?task.id()), err )] - async fn create_aggregation_jobs_for_task( + pub async fn create_aggregation_jobs_for_task( self: Arc<Self>, task: Arc<AggregatorTask>, ) -> anyhow::Result<()> { @@ -1034,7 +1034,7 @@ mod tests { // kill it. const AGGREGATION_JOB_CREATION_INTERVAL: Duration = Duration::from_secs(1); let job_creator = Arc::new(AggregationJobCreator::new( - ds, + Arc::new(ds), noop_meter(), BATCH_AGGREGATION_SHARD_COUNT, Duration::from_secs(3600), @@ -1215,7 +1215,7 @@ mod tests { // Run. let job_creator = Arc::new(AggregationJobCreator::new( - ds, + Arc::new(ds), noop_meter(), BATCH_AGGREGATION_SHARD_COUNT, Duration::from_secs(3600), @@ -1400,7 +1400,7 @@ mod tests { // Run. let job_creator = Arc::new(AggregationJobCreator::new( - ds, + Arc::new(ds), noop_meter(), BATCH_AGGREGATION_SHARD_COUNT, Duration::from_secs(3600), @@ -1626,7 +1626,7 @@ mod tests { // Run. let job_creator = Arc::new(AggregationJobCreator::new( - ds, + Arc::new(ds), noop_meter(), 1, Duration::from_secs(3600), @@ -1793,7 +1793,7 @@ mod tests { // Run. let job_creator = Arc::new(AggregationJobCreator::new( - ds, + Arc::new(ds), noop_meter(), BATCH_AGGREGATION_SHARD_COUNT, Duration::from_secs(3600), @@ -2010,7 +2010,7 @@ mod tests { // Run. let job_creator = Arc::new(AggregationJobCreator::new( - ds, + Arc::new(ds), meter, BATCH_AGGREGATION_SHARD_COUNT, Duration::from_secs(3600), @@ -2179,7 +2179,7 @@ mod tests { // Run. let job_creator = Arc::new(AggregationJobCreator::new( - ds, + Arc::new(ds), meter, BATCH_AGGREGATION_SHARD_COUNT, Duration::from_secs(3600), @@ -2443,7 +2443,7 @@ mod tests { // Run. let job_creator = Arc::new(AggregationJobCreator::new( - ds, + Arc::new(ds), meter, BATCH_AGGREGATION_SHARD_COUNT, Duration::from_secs(3600), @@ -2738,7 +2738,7 @@ mod tests { // Run. let job_creator = Arc::new(AggregationJobCreator::new( - ds, + Arc::new(ds), meter, BATCH_AGGREGATION_SHARD_COUNT, Duration::from_secs(3600), @@ -3030,7 +3030,7 @@ mod tests { // Run.
let job_creator = Arc::new(AggregationJobCreator::new( - ds, + Arc::new(ds), noop_meter(), BATCH_AGGREGATION_SHARD_COUNT, Duration::from_secs(3600), @@ -3235,7 +3235,7 @@ mod tests { .unwrap(); let job_creator = Arc::new(AggregationJobCreator::new( - ds, + Arc::new(ds), noop_meter(), BATCH_AGGREGATION_SHARD_COUNT, Duration::from_secs(3600), diff --git a/aggregator/src/binaries/aggregation_job_creator.rs b/aggregator/src/binaries/aggregation_job_creator.rs index 5b944f65b..e95ce04aa 100644 --- a/aggregator/src/binaries/aggregation_job_creator.rs +++ b/aggregator/src/binaries/aggregation_job_creator.rs @@ -14,7 +14,7 @@ use tracing::info; pub async fn main_callback(ctx: BinaryContext) -> Result<()> { // Start creating aggregation jobs. let aggregation_job_creator = Arc::new(AggregationJobCreator::new( - ctx.datastore, + Arc::new(ctx.datastore), ctx.meter, ctx.config.batch_aggregation_shard_count, Duration::from_secs(ctx.config.tasks_update_frequency_s), diff --git a/integration_tests/Cargo.toml b/integration_tests/Cargo.toml index b9def05db..2a704a179 100644 --- a/integration_tests/Cargo.toml +++ b/integration_tests/Cargo.toml @@ -46,8 +46,15 @@ uuid.workspace = true [dev-dependencies] chrono.workspace = true +derivative.workspace = true divviup-client = { workspace = true, features = ["admin"] } janus_collector = { workspace = true, features = ["test-util"] } +opentelemetry.workspace = true +quickcheck.workspace = true +regex.workspace = true rstest.workspace = true tempfile = { workspace = true } +tracing.workspace = true +trillium.workspace = true +trillium-macros.workspace = true trillium-rustls.workspace = true diff --git a/integration_tests/tests/integration/main.rs b/integration_tests/tests/integration/main.rs index 0b17e2e74..db303e1d8 100644 --- a/integration_tests/tests/integration/main.rs +++ b/integration_tests/tests/integration/main.rs @@ -3,6 +3,7 @@ mod daphne; mod divviup_ts; mod in_cluster; mod janus; +mod simulation; fn initialize_rustls() { // Choose aws-lc-rs as the default rustls crypto provider. This is what's currently enabled by diff --git a/integration_tests/tests/integration/simulation/arbitrary.rs b/integration_tests/tests/integration/simulation/arbitrary.rs new file mode 100644 index 000000000..062520953 --- /dev/null +++ b/integration_tests/tests/integration/simulation/arbitrary.rs @@ -0,0 +1,504 @@ +//! `Arbitrary` implementations to generate simulation inputs. + +// TODO: There will be perennial opportunities to make the distribution of inputs more "realistic" +// and/or "interesting", to control more configuration from simulation inputs, and to introduce new +// forms of fault injection. See https://users.cs.utah.edu/~regehr/papers/swarm12.pdf for example. 
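+// +// Editor's note (a minimal sketch, not part of the original change): the quickcheck plumbing +// these generators build on. `Arbitrary::arbitrary` draws a random value from a `Gen`, `shrink` +// yields smaller candidates once a failing input is found, and `Gen::choose` picks uniformly from +// a slice, so the op tables below bias selection by repeating entries. Illustrative values only: +// +// use quickcheck::Gen; +// let g = &mut Gen::new(100); +// // "Upload" is three times as likely to be chosen as "AdvanceTime". +// let kind = g.choose(&["AdvanceTime", "Upload", "Upload", "Upload"]).unwrap();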
+ +use janus_core::time::TimeExt; +use janus_messages::{CollectionJobId, Duration, Interval, Time}; +use quickcheck::{empty_shrinker, Arbitrary, Gen}; +use rand::random; + +use crate::simulation::{ + model::{Config, Input, Op}, + START_TIME, +}; + +impl Arbitrary for Config { + fn arbitrary(g: &mut Gen) -> Self { + let mut batch_size_limits = [u8::arbitrary(g), u8::arbitrary(g)]; + batch_size_limits.sort(); + + let mut aggregation_job_size_limits = [u8::arbitrary(g), u8::arbitrary(g)]; + aggregation_job_size_limits.sort(); + + Self { + time_precision: Duration::from_seconds(3600), + min_batch_size: batch_size_limits[0].into(), + max_batch_size: bool::arbitrary(g).then_some(batch_size_limits[1].into()), + batch_time_window_size: bool::arbitrary(g) + .then_some(Duration::from_seconds(u8::arbitrary(g).into())), + report_expiry_age: bool::arbitrary(g) + .then_some(Duration::from_seconds(u16::arbitrary(g).into())), + min_aggregation_job_size: aggregation_job_size_limits[0].into(), + max_aggregation_job_size: aggregation_job_size_limits[1].into(), + } + } + + fn shrink(&self) -> Box<dyn Iterator<Item = Self>> { + let mut choices = Vec::with_capacity(3); + if self.max_batch_size.is_some() { + choices.push(Self { + max_batch_size: None, + ..self.clone() + }); + } + if self.batch_time_window_size.is_some() { + choices.push(Self { + batch_time_window_size: None, + ..self.clone() + }); + } + if self.report_expiry_age.is_some() { + choices.push(Self { + report_expiry_age: None, + ..self.clone() + }); + } + Box::new(choices.into_iter()) + } +} + +#[derive(Debug, Clone)] +pub(super) struct TimeIntervalInput(pub(super) Input); + +#[derive(Debug, Clone)] +pub(super) struct FixedSizeInput(pub(super) Input); + +#[derive(Debug, Clone)] +pub(super) struct TimeIntervalFaultInjectionInput(pub(super) Input); + +#[derive(Debug, Clone)] +pub(super) struct FixedSizeFaultInjectionInput(pub(super) Input); + +/// This models the effect that the operations generated so far have on the simulation, and it is +/// used when generating subsequent operations. +struct Context { + current_time: Time, + time_precision: Duration, + started_collection_job_ids: Vec<CollectionJobId>, + polled_collection_job_ids: Vec<CollectionJobId>, +} + +impl Context { + fn new(config: &Config) -> Self { + Self { + current_time: START_TIME, + time_precision: config.time_precision, + started_collection_job_ids: Vec::new(), + polled_collection_job_ids: Vec::new(), + } + } + + fn update(&mut self, op: &Op) { + match op { + Op::AdvanceTime { amount } => { + self.current_time = self.current_time.add(amount).unwrap() + } + Op::CollectorStart { + collection_job_id, + query: _, + } => self.started_collection_job_ids.push(*collection_job_id), + Op::CollectorPoll { collection_job_id } => { + if !self.polled_collection_job_ids.contains(collection_job_id) { + self.polled_collection_job_ids.push(*collection_job_id); + } + } + _ => {} + } + } +} + +/// This is based on `impl Arbitrary for Vec<A>`, but allows passing additional +/// context, and switching between multiple functions to generate an `Op`. +fn arbitrary_vec_with_context( + f: impl Fn(&mut Gen, &Context, &[OpKind]) -> Op, + g: &mut Gen, + mut context: Context, + choices: &[OpKind], +) -> Vec<Op> { + let vec_size = Vec::<()>::arbitrary(g).len(); + let mut output = Vec::with_capacity(vec_size); + for _ in 0..vec_size { + let new_op = f(g, &mut context, choices); + context.update(&new_op); + output.push(new_op); + } + output +} + +/// Shrink a vector of operations.
+ +/// +/// Since `Op` doesn't implement `Arbitrary` itself, we first wrap them in a newtype that does, and +/// then dispatch to the blanket shrinking implementation for `Vec<T>`. This will +/// shrink the input list of operations by removing operations, but not otherwise alter any +/// individual operations. +fn shrink_ops(ops: &[Op]) -> Box<dyn Iterator<Item = Vec<Op>>> { + #[derive(Clone)] + struct Opaque<T>(T); + + impl<T> Arbitrary for Opaque<T> + where + T: Clone + 'static, + { + fn arbitrary(_g: &mut Gen) -> Self { + unimplemented!() + } + + fn shrink(&self) -> Box<dyn Iterator<Item = Self>> { + empty_shrinker() + } + } + + Box::new( + ops.iter() + .map(|op| Opaque(op.clone())) + .collect::<Vec<Opaque<Op>>>() + .shrink() + .map(|ops| ops.iter().map(|wrapped| wrapped.0.clone()).collect()), + ) +} + +/// Generate an upload operation. +fn arbitrary_upload_op(g: &mut Gen, context: &Context) -> Op { + Op::Upload { + report_time: arbitrary_report_time(g, context), + } +} + +/// Generate a replayed upload operation. +fn arbitrary_upload_replay_op(g: &mut Gen, context: &Context) -> Op { + Op::UploadReplay { + report_time: arbitrary_report_time(g, context), + } +} + +/// Generate a random report time for an upload operation. The distribution has extra weight on the +/// current time, because very new or very old reports should be rejected, and thus don't exercise +/// much functionality. +fn arbitrary_report_time(g: &mut Gen, context: &Context) -> Time { + if u8::arbitrary(g) >= 8 { + // now + context.current_time + } else if bool::arbitrary(g) { + // future + context + .current_time + .add(&Duration::from_seconds(u16::arbitrary(g).into())) + .unwrap() + } else { + // past + context + .current_time + .sub(&Duration::from_seconds(u16::arbitrary(g).into())) + .unwrap() + } +} + +/// Generate a collect start operation, using a time interval query. +fn arbitrary_collector_start_op_time_interval(g: &mut Gen, context: &Context) -> Op { + let start_to_now = context.current_time.difference(&START_TIME).unwrap(); + let random_range = start_to_now.as_seconds() / context.time_precision.as_seconds() + 10; + let start = START_TIME + .add(&Duration::from_seconds( + u64::arbitrary(g) % random_range * context.time_precision.as_seconds(), + )) + .unwrap(); + + let duration_fn = g + .choose(&[ + (|_g: &mut Gen, context: &Context| -> Duration { context.time_precision }) + as fn(&mut Gen, &Context) -> Duration, + (|g: &mut Gen, context: &Context| -> Duration { + Duration::from_seconds( + context.time_precision.as_seconds() * (1 + u64::from(u8::arbitrary(g) & 0x1f)), + ) + }) as fn(&mut Gen, &Context) -> Duration, + ]) + .unwrap(); + Op::CollectorStart { + collection_job_id: random(), + query: super::model::Query::TimeInterval( + Interval::new(start, duration_fn(g, context)).unwrap(), + ), + } +} + +fn arbitrary_collector_start_op_fixed_size(g: &mut Gen, context: &Context) -> Op { + if context.polled_collection_job_ids.is_empty() || bool::arbitrary(g) { + Op::CollectorStart { + collection_job_id: random(), + query: super::model::Query::FixedSizeCurrentBatch, + } + } else { + Op::CollectorStart { + collection_job_id: random(), + query: super::model::Query::FixedSizeByBatchId( + *g.choose(&context.polled_collection_job_ids).unwrap(), + ), + } + } +} + +/// Generate a collect poll operation.
+fn arbitrary_collector_poll_op(g: &mut Gen, context: &Context) -> Op { + Op::CollectorPoll { + collection_job_id: g + .choose(&context.started_collection_job_ids) + .copied() + .unwrap_or_else(|| { + CollectionJobId::try_from([0u8; CollectionJobId::LEN].as_slice()).unwrap() + }), + } +} + +impl Arbitrary for TimeIntervalInput { + fn arbitrary(g: &mut Gen) -> Self { + let config = Config::arbitrary(g); + let context = Context::new(&config); + let ops = arbitrary_vec_with_context( + arbitrary_op_time_interval, + g, + context, + choices::OP_KIND_CHOICES, + ); + Self(Input { + is_fixed_size: false, + config, + ops, + }) + } + + fn shrink(&self) -> Box<dyn Iterator<Item = Self>> { + Box::new(shrink_ops(&self.0.ops).map({ + let config = self.0.config.clone(); + let is_fixed_size = self.0.is_fixed_size; + move |ops| { + Self(Input { + config: config.clone(), + ops, + is_fixed_size, + }) + } + })) + } +} + +enum OpKind { + AdvanceTime, + Upload, + UploadReplay, + LeaderGarbageCollector, + HelperGarbageCollector, + LeaderKeyRotator, + HelperKeyRotator, + AggregationJobCreator, + AggregationJobDriver, + AggregationJobDriverRequestError, + AggregationJobDriverResponseError, + CollectionJobDriver, + CollectionJobDriverRequestError, + CollectionJobDriverResponseError, + CollectorStart, + CollectorPoll, +} + +/// Arrays of kinds of operations. These will be used with [`Gen::choose`] to select random +/// operations. Some operations are listed multiple times to bias operation selection. +mod choices { + use super::{OpKind, OpKind::*}; + pub(super) static OP_KIND_CHOICES: &[OpKind] = &[ + AdvanceTime, + Upload, + Upload, + Upload, + Upload, + Upload, + Upload, + Upload, + Upload, + Upload, + Upload, + UploadReplay, + LeaderGarbageCollector, + HelperGarbageCollector, + LeaderKeyRotator, + HelperKeyRotator, + AggregationJobCreator, + AggregationJobDriver, + CollectionJobDriver, + CollectorStart, + CollectorPoll, + ]; + pub(super) static OP_KIND_CHOICES_FAULT_INJECTION: &[OpKind] = &[ + AdvanceTime, + Upload, + Upload, + Upload, + Upload, + Upload, + Upload, + Upload, + Upload, + Upload, + Upload, + UploadReplay, + LeaderGarbageCollector, + HelperGarbageCollector, + LeaderKeyRotator, + HelperKeyRotator, + AggregationJobCreator, + AggregationJobDriver, + AggregationJobDriverRequestError, + AggregationJobDriverResponseError, + CollectionJobDriver, + CollectionJobDriverRequestError, + CollectionJobDriverResponseError, + CollectorStart, + CollectorPoll, + ]; +} + +/// Generate an operation, using time interval queries.
+fn arbitrary_op_time_interval(g: &mut Gen, context: &Context, choices: &[OpKind]) -> Op { + match g.choose(choices).unwrap() { + OpKind::AdvanceTime => Op::AdvanceTime { + amount: Duration::from_seconds(u16::arbitrary(g).into()), + }, + OpKind::Upload => arbitrary_upload_op(g, context), + OpKind::UploadReplay => arbitrary_upload_replay_op(g, context), + OpKind::LeaderGarbageCollector => Op::LeaderGarbageCollector, + OpKind::HelperGarbageCollector => Op::HelperGarbageCollector, + OpKind::LeaderKeyRotator => Op::LeaderKeyRotator, + OpKind::HelperKeyRotator => Op::HelperKeyRotator, + OpKind::AggregationJobCreator => Op::AggregationJobCreator, + OpKind::AggregationJobDriver => Op::AggregationJobDriver, + OpKind::AggregationJobDriverRequestError => Op::AggregationJobDriverRequestError, + OpKind::AggregationJobDriverResponseError => Op::AggregationJobDriverResponseError, + OpKind::CollectionJobDriver => Op::CollectionJobDriver, + OpKind::CollectionJobDriverRequestError => Op::CollectionJobDriverRequestError, + OpKind::CollectionJobDriverResponseError => Op::CollectionJobDriverResponseError, + OpKind::CollectorStart => arbitrary_collector_start_op_time_interval(g, context), + OpKind::CollectorPoll => arbitrary_collector_poll_op(g, context), + } +} + +impl Arbitrary for FixedSizeInput { + fn arbitrary(g: &mut Gen) -> Self { + let config = Config::arbitrary(g); + let context = Context::new(&config); + let ops = arbitrary_vec_with_context( + arbitrary_op_fixed_size, + g, + context, + choices::OP_KIND_CHOICES, + ); + Self(Input { + is_fixed_size: true, + config, + ops, + }) + } + + fn shrink(&self) -> Box<dyn Iterator<Item = Self>> { + Box::new(shrink_ops(&self.0.ops).map({ + let config = self.0.config.clone(); + let is_fixed_size = self.0.is_fixed_size; + move |ops| { + Self(Input { + config: config.clone(), + ops, + is_fixed_size, + }) + } + })) + } +} + +/// Generate an operation, using fixed size queries.
+fn arbitrary_op_fixed_size(g: &mut Gen, context: &Context, choices: &[OpKind]) -> Op { + match g.choose(choices).unwrap() { + OpKind::AdvanceTime => Op::AdvanceTime { + amount: Duration::from_seconds(u16::arbitrary(g).into()), + }, + OpKind::Upload => arbitrary_upload_op(g, context), + OpKind::UploadReplay => arbitrary_upload_replay_op(g, context), + OpKind::LeaderGarbageCollector => Op::LeaderGarbageCollector, + OpKind::HelperGarbageCollector => Op::HelperGarbageCollector, + OpKind::LeaderKeyRotator => Op::LeaderKeyRotator, + OpKind::HelperKeyRotator => Op::HelperKeyRotator, + OpKind::AggregationJobCreator => Op::AggregationJobCreator, + OpKind::AggregationJobDriver => Op::AggregationJobDriver, + OpKind::AggregationJobDriverRequestError => Op::AggregationJobDriverRequestError, + OpKind::AggregationJobDriverResponseError => Op::AggregationJobDriverResponseError, + OpKind::CollectionJobDriver => Op::CollectionJobDriver, + OpKind::CollectionJobDriverRequestError => Op::CollectionJobDriverRequestError, + OpKind::CollectionJobDriverResponseError => Op::CollectionJobDriverResponseError, + OpKind::CollectorStart => arbitrary_collector_start_op_fixed_size(g, context), + OpKind::CollectorPoll => arbitrary_collector_poll_op(g, context), + } +} + +impl Arbitrary for TimeIntervalFaultInjectionInput { + fn arbitrary(g: &mut Gen) -> Self { + let config = Config::arbitrary(g); + let context = Context::new(&config); + let ops = arbitrary_vec_with_context( + arbitrary_op_time_interval, + g, + context, + choices::OP_KIND_CHOICES_FAULT_INJECTION, + ); + Self(Input { + is_fixed_size: false, + config, + ops, + }) + } + + fn shrink(&self) -> Box<dyn Iterator<Item = Self>> { + Box::new(shrink_ops(&self.0.ops).map({ + let config = self.0.config.clone(); + let is_fixed_size = self.0.is_fixed_size; + move |ops| { + Self(Input { + config: config.clone(), + ops, + is_fixed_size, + }) + } + })) + } +} + +impl Arbitrary for FixedSizeFaultInjectionInput { + fn arbitrary(g: &mut Gen) -> Self { + let config = Config::arbitrary(g); + let context = Context::new(&config); + let ops = arbitrary_vec_with_context( + arbitrary_op_fixed_size, + g, + context, + choices::OP_KIND_CHOICES_FAULT_INJECTION, + ); + Self(Input { + is_fixed_size: true, + config, + ops, + }) + } + + fn shrink(&self) -> Box<dyn Iterator<Item = Self>> { + Box::new(shrink_ops(&self.0.ops).map({ + let config = self.0.config.clone(); + let is_fixed_size = self.0.is_fixed_size; + move |ops| { + Self(Input { + config: config.clone(), + ops, + is_fixed_size, + }) + } + })) + } +} diff --git a/integration_tests/tests/integration/simulation/mod.rs b/integration_tests/tests/integration/simulation/mod.rs new file mode 100644 index 000000000..f149c93d4 --- /dev/null +++ b/integration_tests/tests/integration/simulation/mod.rs @@ -0,0 +1,84 @@ +//! This integration test attempts to implement discrete event simulation in Janus. The goal of this +//! test is to uncover more bugs around garbage collection, network errors, runtime errors, or other +//! failures to uphold certain DAP invariants. We will try to avoid introducing nondeterminism as +//! much as possible, given the current architecture, though we will necessarily fall short of +//! complete determinism. +//! +//! The operating system, Postgres, and Rust synchronization primitives are outside the control of +//! our simulation, so nondeterminism introduced by these sources is out of scope. While this set of +//! tests may trigger bugs that involve these phenomena, it will not be able to reproduce them +//! reliably like in-scope bugs.
There will be room to fix some sources of nondeterminism short of +//! these limits, in order to reproduce more bugs repeatably. For example, operations can be +//! serialized, to avoid running multiple database transactions concurrently, and parallelism in the +//! aggregation job driver and collection job driver can be eliminated. +//! +//! Prio3Histogram is the only VDAF used, and reports are carefully crafted to allow aggregate +//! results to be verified, even in the face of partial data loss. Each report will have a unique +//! measurement, so correct aggregate results must consist of only zeros and ones (and moreover, +//! they should be all zeros for buckets that were never submitted in reports). A two in any +//! position would indicate that a report was replayed or counted twice, while very large numbers +//! would suggest an undetected batch mismatch between the leader and helper, or incorrect +//! aggregation by one of the aggregators. There will only be one DAP task in use at a time. Both +//! TimeInterval and FixedSize query types should be supported, as their implementations are very +//! different. +//! +//! The simulation consists of multiple components making up a client, two aggregators, and a +//! collector. All components for each aggregator share a database, and all components across the +//! simulation share a `MockClock`. None of the components should run any asynchronous tasks +//! continuously throughout the simulation (except for tokio-postgres connection tasks). +//! Initialization will be akin to `JanusInProcessPair`, but lower-level. The simulation is fed a +//! list of [`Op`](model::Op) values, and it executes the operations described one after another. +//! +//! The following are possible failure conditions: +//! - The main Tokio task panics while calling into any component. +//! - Any spawned Tokio task managed by a `TestRuntime` panics. +//! - The collector gets an aggregate result that is impossible. +//! * Any array element is greater than one. +//! * Any array element is nonzero and none of the reports contributed to that bucket. +//! - The collector gets multiple successful responses when polling the same collection job (across +//! multiple operations) and they are not equal. +//! - The helper sends an error response with a "batch mismatch" problem type in response to an +//! aggregate share request from the leader. +//! - The leader sends two aggregation job initialization requests with the same ID, but different +//! contents. +//! - An individual operation exceeds some timeout. +//! +//! The following are explicitly not failure conditions: +//! - The client gets an error from the leader when trying to upload a report, because the timestamp +//! is too old or too new. +//! - The collector gets an error back from the leader indicating a batch can't be collected yet. +//! +//! Note that, due to known issues, Janus would currently fail a liveness criterion like "if a report +//! is uploaded with a timestamp near enough the current time, and the leader's components run +//! enough times before time advances too much, and its batch was collected after the report was +//! uploaded, then it should show up in the results." In particular, it's possible for fresh reports +//! to be combined with about-to-expire reports in aggregation jobs, and in such cases the fresh +//! report would be lost if time advanced a small amount before aggregation happened. +//! +//! It may be possible to impose a collection job liveness criterion, along the lines of "if the
aggregation job driver runs 'enough', the collection job driver then runs 'enough', and a +//! collection job request is then polled, the job should either finish or fail." +//! +//! ## Known sources of nondeterminism +//! +//! - Timing of network syscalls. +//! - Database anomalies allowed at REPEATABLE READ. +//! - Stochastic behavior of Postgres query planner (leading to different row orders). +//! - `SKIP LOCKED` clauses in database queries. +//! - Any parallelization of database-related futures. +//! - Randomly-selected `ord` values in tables with "sharded" rows. +//! - Timing of asynchronous tasks. +//! - Randomness used by `tokio::select!`. +//! - Application-level concurrency bugs. + +use janus_messages::Time; + +const START_TIME: Time = Time::from_seconds_since_epoch(1_700_000_000); + +mod arbitrary; +mod model; +mod proxy; +mod quicktest; +mod reproduction; +mod run; +mod setup; diff --git a/integration_tests/tests/integration/simulation/model.rs b/integration_tests/tests/integration/simulation/model.rs new file mode 100644 index 000000000..6846f92b2 --- /dev/null +++ b/integration_tests/tests/integration/simulation/model.rs @@ -0,0 +1,118 @@ +use janus_messages::{CollectionJobId, Duration, Interval, Time}; + +#[derive(Debug, Clone)] +pub(super) struct Input { + /// Task query type selector. This is fixed by the test harness, and not randomly generated. + pub(super) is_fixed_size: bool, + + /// Combination of Janus configuration and task parameters. + pub(super) config: Config, + + /// Simulation operations to run. + pub(super) ops: Vec<Op>, +} + +#[derive(Debug, Clone)] +pub(super) struct Config { + /// DAP task parameter: time precision. + pub(super) time_precision: Duration, + + /// DAP task parameter: minimum batch size. + pub(super) min_batch_size: u64, + + /// DAP task parameter: maximum batch size. This is only used with fixed size tasks, and ignored + /// otherwise. + pub(super) max_batch_size: Option<u64>, + + /// Janus-specific task parameter: batch time window size (for the time-bucketed fixed size + /// feature). This is only used with fixed size tasks, and ignored otherwise. + pub(super) batch_time_window_size: Option<Duration>, + + /// Janus-specific task parameter: report expiry age (for garbage collection). + pub(super) report_expiry_age: Option<Duration>, + + /// Aggregation job creator configuration: minimum aggregation job size. + pub(super) min_aggregation_job_size: usize, + + /// Aggregation job creator configuration: maximum aggregation job size. + pub(super) max_aggregation_job_size: usize, +} + +#[derive(Debug, Clone)] +pub(super) enum Op { + /// Advance the `MockClock`'s time by `amount`. + AdvanceTime { amount: Duration }, + + /// Have the client shard a report at the given timestamp, with the next sequential measurement, + /// and send it to the leader aggregator. The leader will handle the request and store the + /// report to the database. Note that, as currently implemented, this will wait for the report + /// batching timeout to expire, so the client's upload method won't return until the leader's + /// database transaction is complete. + Upload { report_time: Time }, + + /// Have the client shard a report at the given timestamp as with `Upload`, but with a fixed + /// report ID. + UploadReplay { report_time: Time }, + + /// Run the garbage collector once in the leader. + LeaderGarbageCollector, + + /// Run the garbage collector once in the helper. + HelperGarbageCollector, + + /// Run the key rotator once in the leader.
+ LeaderKeyRotator, + + /// Run the key rotator once in the helper. + HelperKeyRotator, + + /// Run the aggregation job creator once. + AggregationJobCreator, + + /// Run the aggregation job driver once, and wait until it is done stepping all the jobs it + /// acquired. Requests and responses will pass through an inspecting proxy in front of the + /// helper. + AggregationJobDriver, + + /// Same as `AggregationJobDriver`, with fault injection. Drop all requests and return some sort + /// of error. + AggregationJobDriverRequestError, + + /// Same as `AggregationJobDriver`, with fault injection. Forward all requests, but drop the + /// responses, and return some sort of error. + AggregationJobDriverResponseError, + + /// Run the collection job driver once, and wait until it is done stepping all the jobs it + /// acquired. Requests and responses will pass through an inspecting proxy in front of the + /// helper. + CollectionJobDriver, + + /// Same as `CollectionJobDriver`, with fault injection. Drop all requests and return some sort + /// of error. + CollectionJobDriverRequestError, + + /// Same as `CollectionJobDriver`, with fault injection. Forward all requests, but drop the + /// responses, and return some sort of error. + CollectionJobDriverResponseError, + + /// The collector sends a collection request to the leader. It remembers the collection job ID. + CollectorStart { + collection_job_id: CollectionJobId, + query: Query, + }, + + /// The collector sends a request to the leader to poll an existing collection job. + CollectorPoll { collection_job_id: CollectionJobId }, +} + +/// Representation of a DAP query used in a collection job. +#[derive(Debug, Clone)] +pub(super) enum Query { + /// A time interval query, parameterized with a batch interval. + TimeInterval(Interval), + /// A current batch query. + FixedSizeCurrentBatch, + /// A "by batch ID" query. The batch ID will be taken from a previous collection result, with + /// the given collection job ID. + FixedSizeByBatchId(CollectionJobId), +} diff --git a/integration_tests/tests/integration/simulation/proxy.rs b/integration_tests/tests/integration/simulation/proxy.rs new file mode 100644 index 000000000..7e1c22902 --- /dev/null +++ b/integration_tests/tests/integration/simulation/proxy.rs @@ -0,0 +1,178 @@ +use std::{ + borrow::Cow, + sync::{Arc, Mutex, OnceLock}, +}; + +use regex::bytes::Regex; +use tracing::error; +use trillium::{Conn, Handler, Status}; +use trillium_macros::Handler; + +/// A [`Handler`] wrapper that can be configured to drop requests or responses. +#[derive(Handler)] +pub(super) struct FaultInjectorHandler<H> { + #[handler(except = [run, before_send, name])] + inner: H, + + /// Flag to inject an error before request handling. This will skip running the wrapped + /// `Handler`. + error_before: Arc<Mutex<bool>>, + + /// Flag to inject an error after request handling. This will drop the response and replace it + /// with an error response.
+ error_after: Arc<Mutex<bool>>, +} + +impl<H> FaultInjectorHandler<H> { + pub fn new(handler: H) -> Self { + Self { + inner: handler, + error_before: Arc::new(Mutex::new(false)), + error_after: Arc::new(Mutex::new(false)), + } + } + + pub fn controller(&self) -> FaultInjector { + FaultInjector { + error_before: Arc::clone(&self.error_before), + error_after: Arc::clone(&self.error_after), + } + } +} + +struct FaultInjectorMarker; + +impl<H: Handler> FaultInjectorHandler<H> { + async fn run(&self, mut conn: Conn) -> Conn { + conn.insert_state(FaultInjectorMarker); + if *self.error_before.lock().unwrap() { + conn.with_status(Status::InternalServerError) + } else { + self.inner.run(conn).await + } + } + + async fn before_send(&self, conn: Conn) -> Conn { + let mut conn = self.inner.before_send(conn).await; + if conn.state::<FaultInjectorMarker>().is_some() && *self.error_after.lock().unwrap() { + conn.set_status(Status::InternalServerError); + let header_names = conn + .response_headers() + .iter() + .map(|(name, _)| name.to_owned()) + .collect::<Vec<_>>(); + conn.response_headers_mut().remove_all(header_names); + conn.set_body(""); + } + conn + } + + fn name(&self) -> Cow<'static, str> { + format!("FaultInjectorHandler({})", std::any::type_name::<H>()).into() + } +} + +/// This controls a [`FaultInjectorHandler`]. +pub(super) struct FaultInjector { + error_before: Arc<Mutex<bool>>, + error_after: Arc<Mutex<bool>>, +} + +impl FaultInjector { + /// Disable all fault injection. + pub fn reset(&self) { + *self.error_before.lock().unwrap() = false; + *self.error_after.lock().unwrap() = false; + } + + /// Inject an error before request handling. This will skip running the wrapped `Handler`. + pub fn error_before(&self) { + *self.error_before.lock().unwrap() = true; + } + + /// Inject an error after request handling. This will drop the response and replace it with an + /// error response. + pub fn error_after(&self) { + *self.error_after.lock().unwrap() = true; + } +} + +/// A [`Handler`] wrapper that inspects request and response bodies, in order to trigger test failures. +#[derive(Handler)] +pub(super) struct InspectHandler<H> { + #[handler(except = [run, before_send, name])] + inner: H, + failure: Arc<Mutex<bool>>, +} + +impl<H> InspectHandler<H> { + pub fn new(handler: H) -> Self { + Self { + inner: handler, + failure: Arc::new(Mutex::new(false)), + } + } + + pub fn monitor(&self) -> InspectMonitor { + InspectMonitor { + failure: Arc::clone(&self.failure), + } + } +} + +struct InspectMarker; + +impl<H: Handler> InspectHandler<H> { + async fn run(&self, mut conn: Conn) -> Conn { + conn.insert_state(InspectMarker); + self.inner.run(conn).await + } + + async fn before_send(&self, conn: Conn) -> Conn { + let mut conn = self.inner.before_send(conn).await; + if conn.state::<InspectMarker>().is_some() { + if conn.status() == Some(Status::Conflict) { + error!("409 Conflict response"); + *self.failure.lock().unwrap() = true; + } + if conn.path().ends_with("/aggregate_shares") { + inspect_response_body(&mut conn, |bytes| { + static ONCE: OnceLock<Regex> = OnceLock::new(); + let batch_mismatch_regex = ONCE.get_or_init(|| { + Regex::new("urn:ietf:params:ppm:dap:error:batchMismatch").unwrap() + }); + if batch_mismatch_regex.is_match(bytes) { + error!("batch mismatch response"); + *self.failure.lock().unwrap() = true; + } + }) + .await; + } + } + conn + } + + fn name(&self) -> Cow<'static, str> { + format!("InspectHandler({})", std::any::type_name::<H>()).into() + } +} + +/// Takes the response body from a connection, runs the provided closure on it, and replaces the +/// response body. If no body has been set yet, the closure is not run.
+async fn inspect_response_body(conn: &mut Conn, f: impl Fn(&[u8])) { + if let Some(body) = conn.take_response_body() { + let bytes = body.into_bytes().await.unwrap(); + f(&bytes); + conn.set_body(bytes); + } +} + +pub(super) struct InspectMonitor { + failure: Arc<Mutex<bool>>, +} + +impl InspectMonitor { + pub fn has_failed(&self) -> bool { + *self.failure.lock().unwrap() + } +} diff --git a/integration_tests/tests/integration/simulation/quicktest.rs b/integration_tests/tests/integration/simulation/quicktest.rs new file mode 100644 index 000000000..3f3a24f44 --- /dev/null +++ b/integration_tests/tests/integration/simulation/quicktest.rs @@ -0,0 +1,52 @@ +use janus_core::test_util::install_test_trace_subscriber; +use quickcheck::{QuickCheck, TestResult}; + +use crate::simulation::{ + arbitrary::{ + FixedSizeFaultInjectionInput, FixedSizeInput, TimeIntervalFaultInjectionInput, + TimeIntervalInput, + }, + run::Simulation, +}; + +#[test] +#[ignore = "slow quickcheck test"] +fn simulation_test_time_interval_no_fault_injection() { + install_test_trace_subscriber(); + + QuickCheck::new().quickcheck( + (|TimeIntervalInput(input)| Simulation::run(input)) as fn(TimeIntervalInput) -> TestResult, + ); +} + +#[test] +#[ignore = "slow quickcheck test"] +fn simulation_test_fixed_size_no_fault_injection() { + install_test_trace_subscriber(); + + QuickCheck::new().quickcheck( + (|FixedSizeInput(input)| Simulation::run(input)) as fn(FixedSizeInput) -> TestResult, + ); +} + +#[test] +#[ignore = "slow quickcheck test"] +fn simulation_test_time_interval_with_fault_injection() { + install_test_trace_subscriber(); + + QuickCheck::new().quickcheck( + (|TimeIntervalFaultInjectionInput(input)| Simulation::run(input)) + as fn(TimeIntervalFaultInjectionInput) -> TestResult, + ); +} + +#[test] +#[ignore = "slow quickcheck test"] +fn simulation_test_fixed_size_with_fault_injection() { + install_test_trace_subscriber(); + + QuickCheck::new().quickcheck( + (|FixedSizeFaultInjectionInput(input)| Simulation::run(input)) + as fn(FixedSizeFaultInjectionInput) -> TestResult, + ); +} diff --git a/integration_tests/tests/integration/simulation/reproduction.rs b/integration_tests/tests/integration/simulation/reproduction.rs new file mode 100644 index 000000000..758e930ba --- /dev/null +++ b/integration_tests/tests/integration/simulation/reproduction.rs @@ -0,0 +1,384 @@ +use janus_core::{test_util::install_test_trace_subscriber, time::TimeExt}; +use janus_messages::{Duration, Interval, Time}; +use rand::random; + +use crate::simulation::{ + model::{Config, Input, Op, Query}, + run::Simulation, + START_TIME, +}; + +#[test] +fn successful_collection_time_interval() { + install_test_trace_subscriber(); + + let collection_job_id = random(); + let input = Input { + is_fixed_size: false, + config: Config { + time_precision: Duration::from_seconds(3600), + min_batch_size: 4, + max_batch_size: None, + batch_time_window_size: None, + report_expiry_age: Some(Duration::from_seconds(7200)), + min_aggregation_job_size: 1, + max_aggregation_job_size: 10, + }, + ops: Vec::from([ + Op::Upload { + report_time: START_TIME, + }, + Op::AggregationJobCreator, + Op::AggregationJobDriver, + Op::LeaderGarbageCollector, + Op::Upload { + report_time: START_TIME, + }, + Op::AggregationJobCreator, + Op::AggregationJobDriver, + Op::LeaderGarbageCollector, + Op::Upload { + report_time: START_TIME, + }, + Op::AggregationJobCreator, + Op::AggregationJobDriver, + Op::LeaderGarbageCollector, + Op::CollectorStart { + collection_job_id, + query: Query::TimeInterval(
Interval::new( + Time::from_seconds_since_epoch(1_699_999_200), + Duration::from_seconds(3600), + ) + .unwrap(), + ), + }, + Op::CollectionJobDriver, + Op::CollectorPoll { collection_job_id }, + Op::Upload { + report_time: START_TIME, + }, + Op::Upload { + report_time: START_TIME, + }, + Op::Upload { + report_time: START_TIME, + }, + Op::Upload { + report_time: START_TIME, + }, + Op::AggregationJobCreator, + Op::AggregationJobDriver, + Op::CollectorStart { + collection_job_id, + query: Query::TimeInterval( + Interval::new( + Time::from_seconds_since_epoch(1_699_999_200), + Duration::from_seconds(3600), + ) + .unwrap(), + ), + }, + Op::CollectionJobDriver, + Op::CollectorPoll { collection_job_id }, + ]), + }; + assert!(!Simulation::run(input).is_failure()); +} + +#[test] +fn successful_collection_fixed_size() { + install_test_trace_subscriber(); + + let collection_job_id = random(); + let input = Input { + is_fixed_size: true, + config: Config { + time_precision: Duration::from_seconds(3600), + min_batch_size: 4, + max_batch_size: Some(6), + batch_time_window_size: None, + report_expiry_age: Some(Duration::from_seconds(7200)), + min_aggregation_job_size: 1, + max_aggregation_job_size: 10, + }, + ops: Vec::from([ + Op::Upload { + report_time: START_TIME, + }, + Op::AggregationJobCreator, + Op::AggregationJobDriver, + Op::LeaderGarbageCollector, + Op::Upload { + report_time: START_TIME, + }, + Op::AggregationJobCreator, + Op::AggregationJobDriver, + Op::LeaderGarbageCollector, + Op::Upload { + report_time: START_TIME, + }, + Op::AggregationJobCreator, + Op::AggregationJobDriver, + Op::LeaderGarbageCollector, + Op::CollectorStart { + collection_job_id, + query: Query::FixedSizeCurrentBatch, + }, + Op::CollectionJobDriver, + Op::CollectorPoll { collection_job_id }, + Op::Upload { + report_time: START_TIME, + }, + Op::Upload { + report_time: START_TIME, + }, + Op::Upload { + report_time: START_TIME, + }, + Op::Upload { + report_time: START_TIME, + }, + Op::AggregationJobCreator, + Op::AggregationJobDriver, + Op::CollectorStart { + collection_job_id, + query: Query::FixedSizeCurrentBatch, + }, + Op::CollectionJobDriver, + Op::CollectorPoll { collection_job_id }, + ]), + }; + assert!(!Simulation::run(input).is_failure()); +} + +#[test] +#[ignore = "failing test"] +/// Reproduction of https://github.com/divviup/janus/issues/3323. 
+fn repro_slow_uploads_with_max_batch_size() { + install_test_trace_subscriber(); + + let collection_job_id = random(); + let input = Input { + is_fixed_size: true, + config: Config { + time_precision: Duration::from_seconds(3600), + min_batch_size: 4, + max_batch_size: Some(6), + batch_time_window_size: None, + report_expiry_age: Some(Duration::from_seconds(7200)), + min_aggregation_job_size: 1, + max_aggregation_job_size: 10, + }, + ops: Vec::from([ + Op::Upload { + report_time: START_TIME, + }, + Op::AggregationJobCreator, + Op::AggregationJobDriver, + Op::AdvanceTime { + amount: Duration::from_seconds(3600), + }, + Op::LeaderGarbageCollector, + Op::Upload { + report_time: Time::from_seconds_since_epoch(1_700_003_600), + }, + Op::AggregationJobCreator, + Op::AggregationJobDriver, + Op::AdvanceTime { + amount: Duration::from_seconds(3600), + }, + Op::LeaderGarbageCollector, + Op::Upload { + report_time: Time::from_seconds_since_epoch(1_700_007_200), + }, + Op::AggregationJobCreator, + Op::AggregationJobDriver, + Op::AdvanceTime { + amount: Duration::from_seconds(3600), + }, + Op::LeaderGarbageCollector, + Op::Upload { + report_time: Time::from_seconds_since_epoch(1_700_010_800), + }, + Op::Upload { + report_time: Time::from_seconds_since_epoch(1_700_010_800), + }, + Op::Upload { + report_time: Time::from_seconds_since_epoch(1_700_010_800), + }, + Op::Upload { + report_time: Time::from_seconds_since_epoch(1_700_010_800), + }, + Op::AggregationJobCreator, + Op::AggregationJobDriver, + Op::CollectorStart { + collection_job_id, + query: Query::FixedSizeCurrentBatch, + }, + Op::CollectionJobDriver, + Op::CollectorPoll { collection_job_id }, + ]), + }; + assert!(!Simulation::run(input).is_failure()); +} + +#[test] +/// Regression test for https://github.com/divviup/janus/issues/2442. +fn repro_gc_changes_aggregation_job_retry_time_interval() { + install_test_trace_subscriber(); + + let input = Input { + is_fixed_size: false, + config: Config { + time_precision: Duration::from_seconds(3600), + min_batch_size: 1, + max_batch_size: None, + batch_time_window_size: None, + report_expiry_age: Some(Duration::from_seconds(7200)), + min_aggregation_job_size: 2, + max_aggregation_job_size: 2, + }, + ops: Vec::from([ + Op::Upload { + report_time: START_TIME, + }, + Op::AdvanceTime { + amount: Duration::from_seconds(3600), + }, + Op::Upload { + report_time: START_TIME.add(&Duration::from_seconds(3600)).unwrap(), + }, + Op::AggregationJobCreator, + Op::AggregationJobDriverResponseError, + Op::AdvanceTime { + amount: Duration::from_seconds(5400), + }, + Op::LeaderGarbageCollector, + Op::AggregationJobDriver, + ]), + }; + assert!(!Simulation::run(input).is_failure()); +} + +#[test] +/// Regression test for https://github.com/divviup/janus/issues/2442. 
+fn repro_gc_changes_aggregation_job_retry_fixed_size() { + install_test_trace_subscriber(); + + let input = Input { + is_fixed_size: true, + config: Config { + time_precision: Duration::from_seconds(3600), + min_batch_size: 1, + max_batch_size: None, + batch_time_window_size: None, + report_expiry_age: Some(Duration::from_seconds(7200)), + min_aggregation_job_size: 2, + max_aggregation_job_size: 2, + }, + ops: Vec::from([ + Op::Upload { + report_time: START_TIME, + }, + Op::AdvanceTime { + amount: Duration::from_seconds(3600), + }, + Op::Upload { + report_time: START_TIME.add(&Duration::from_seconds(3600)).unwrap(), + }, + Op::AggregationJobCreator, + Op::AggregationJobDriverResponseError, + Op::AdvanceTime { + amount: Duration::from_seconds(5400), + }, + Op::LeaderGarbageCollector, + Op::AggregationJobDriver, + ]), + }; + assert!(!Simulation::run(input).is_failure()); +} + +#[test] +/// Regression test for https://github.com/divviup/janus/issues/2464. +fn repro_recreate_gcd_batch_job_count_underflow() { + install_test_trace_subscriber(); + + let input = Input { + is_fixed_size: false, + config: Config { + time_precision: Duration::from_seconds(1000), + min_batch_size: 100, + max_batch_size: None, + batch_time_window_size: None, + report_expiry_age: Some(Duration::from_seconds(4000)), + min_aggregation_job_size: 2, + max_aggregation_job_size: 2, + }, + ops: Vec::from([ + Op::Upload { + report_time: START_TIME, + }, + Op::AdvanceTime { + amount: Duration::from_seconds(2000), + }, + Op::Upload { + report_time: START_TIME.add(&Duration::from_seconds(2000)).unwrap(), + }, + Op::AggregationJobCreator, + Op::AdvanceTime { + amount: Duration::from_seconds(3500), + }, + Op::AggregationJobDriver, + ]), + }; + assert!(!Simulation::run(input).is_failure()); +} + +#[test] +#[ignore = "failing test"] +fn repro_abandoned_aggregation_job_batch_mismatch() { + install_test_trace_subscriber(); + + let collection_job_id = random(); + let input = Input { + is_fixed_size: false, + config: Config { + time_precision: Duration::from_seconds(1000), + min_batch_size: 1, + max_batch_size: None, + batch_time_window_size: None, + report_expiry_age: None, + min_aggregation_job_size: 1, + max_aggregation_job_size: 1, + }, + ops: Vec::from([ + Op::Upload { + report_time: START_TIME, + }, + Op::AggregationJobCreator, + Op::AggregationJobDriver, + Op::Upload { + report_time: START_TIME, + }, + Op::AggregationJobCreator, + Op::AggregationJobDriverResponseError, + Op::AdvanceTime { + amount: Duration::from_seconds(610), + }, + Op::AggregationJobDriverResponseError, + Op::AdvanceTime { + amount: Duration::from_seconds(610), + }, + Op::AggregationJobDriver, + Op::CollectorStart { + collection_job_id, + query: Query::TimeInterval( + Interval::new(START_TIME, Duration::from_seconds(1000)).unwrap(), + ), + }, + Op::CollectionJobDriver, + ]), + }; + assert!(!Simulation::run(input).is_failure()); +} diff --git a/integration_tests/tests/integration/simulation/run.rs b/integration_tests/tests/integration/simulation/run.rs new file mode 100644 index 000000000..70dd75a52 --- /dev/null +++ b/integration_tests/tests/integration/simulation/run.rs @@ -0,0 +1,662 @@ +use std::{ + collections::HashMap, + ops::ControlFlow, + panic::{catch_unwind, AssertUnwindSafe}, + sync::Arc, + time::Duration as StdDuration, +}; + +use backoff::ExponentialBackoff; +use derivative::Derivative; +use divviup_client::{Decode, Encode}; +use http::header::CONTENT_TYPE; +use janus_aggregator::aggregator; +use janus_aggregator_core::{ + 
datastore::models::AggregatorRole, + task::{test_util::Task, AggregatorTask}, + test_util::noop_meter, +}; +use janus_collector::{Collection, CollectionJob, PollResult}; +use janus_core::{ + hpke::{self, HpkeApplicationInfo, Label}, + http::HttpErrorResponse, + retries::retry_http_request, + test_util::runtime::TestRuntimeManager, + time::{Clock, MockClock, TimeExt}, + vdaf::{vdaf_dp_strategies, VdafInstance}, +}; +use janus_messages::{ + query_type::{FixedSize, TimeInterval}, + CollectionJobId, Duration, FixedSizeQuery, HpkeConfig, HpkeConfigList, InputShareAad, + PlaintextInputShare, Report, ReportId, ReportMetadata, Role, Time, +}; +use opentelemetry::metrics::Meter; +use prio::vdaf::{ + prio3::{optimal_chunk_length, Prio3, Prio3Histogram}, + Client as _, +}; +use quickcheck::TestResult; +use tokio::time::timeout; +use tracing::{debug, error, info, info_span, warn, Instrument}; +use trillium_tokio::Stopper; +use url::Url; + +use crate::simulation::{ + model::{Input, Op, Query}, + setup::Components, + START_TIME, +}; + +const MAX_REPORTS: usize = 1_000; + +pub(super) struct Simulation { + state: State, + components: Components, + task: Task, + leader_task: Arc<AggregatorTask>, +} + +impl Simulation { + async fn new(input: &Input) -> Self { + let mut state = State::new(); + let (components, task) = Components::setup(input, &mut state).await; + let leader_task = Arc::new(task.leader_view().unwrap()); + Self { + state, + components, + task, + leader_task, + } + } + + pub(super) fn run(input: Input) -> TestResult { + let tokio_runtime = tokio::runtime::Builder::new_multi_thread() + .enable_all() + .build() + .unwrap(); + + tokio_runtime.block_on(async { + let mut simulation = Self::new(&input).await; + for op in input.ops.iter() { + let span = info_span!("operation", op = ?op); + let timeout_result = timeout(StdDuration::from_secs(15), async { + info!(time = ?simulation.state.clock.now(), "starting operation"); + let result = match op { + Op::AdvanceTime { amount } => simulation.execute_advance_time(amount).await, + Op::Upload { report_time } => simulation.execute_upload(report_time).await, + Op::UploadReplay { report_time } => { + simulation.execute_upload_replay(report_time).await + } + Op::LeaderGarbageCollector => { + simulation + .execute_garbage_collector(AggregatorRole::Leader) + .await + } + Op::HelperGarbageCollector => { + simulation + .execute_garbage_collector(AggregatorRole::Helper) + .await + } + Op::LeaderKeyRotator => { + simulation.execute_key_rotator(AggregatorRole::Leader).await + } + Op::HelperKeyRotator => { + simulation.execute_key_rotator(AggregatorRole::Helper).await + } + Op::AggregationJobCreator => { + simulation.execute_aggregation_job_creator().await + } + Op::AggregationJobDriver => { + simulation.execute_aggregation_job_driver().await + } + Op::AggregationJobDriverRequestError => { + simulation + .execute_aggregation_job_driver_request_error() + .await + } + Op::AggregationJobDriverResponseError => { + simulation + .execute_aggregation_job_driver_response_error() + .await + } + Op::CollectionJobDriver => simulation.execute_collection_job_driver().await, + Op::CollectionJobDriverRequestError => { + simulation + .execute_collection_job_driver_request_error() + .await + } + Op::CollectionJobDriverResponseError => { + simulation + .execute_collection_job_driver_response_error() + .await + } + Op::CollectorStart { + collection_job_id, + query, + } => { + simulation + .execute_collector_start(collection_job_id, query) + .await + } + Op::CollectorPoll { collection_job_id } => {
simulation.execute_collector_poll(collection_job_id).await + } + }; + info!("finished operation"); + result + }) + .instrument(span) + .await; + match timeout_result { + Ok(ControlFlow::Break(test_result)) => return test_result, + Ok(ControlFlow::Continue(())) => {} + Err(error) => return TestResult::error(error.to_string()), + } + + if simulation.components.leader.inspect_monitor.has_failed() + || simulation.components.helper.inspect_monitor.has_failed() + { + return TestResult::failed(); + } + } + + if !check_aggregate_results_valid( + &simulation.state.aggregate_results_time_interval, + &simulation.state, + ) { + return TestResult::failed(); + } + + if !check_aggregate_results_valid( + &simulation.state.aggregate_results_fixed_size, + &simulation.state, + ) { + return TestResult::failed(); + } + + // `TestRuntimeManager` will panic on drop if any asynchronous task spawned via its + // labeled runtimes panicked. Drop the `Simulation` struct, which includes this manager, + // inside `catch_unwind` so we can report failure. + if catch_unwind(AssertUnwindSafe(move || drop(simulation))).is_err() { + return TestResult::failed(); + } + + TestResult::passed() + }) + } + + async fn execute_advance_time(&mut self, amount: &Duration) -> ControlFlow<TestResult> { + self.state.clock.advance(amount); + ControlFlow::Continue(()) + } + + async fn execute_upload(&mut self, report_time: &Time) -> ControlFlow<TestResult> { + if let Some(measurement) = self.state.next_measurement() { + if let Err(error) = self + .components + .client + .upload_with_time(&measurement, *report_time) + .await + { + warn!(?error, "client error"); + // We expect to receive an error if the report timestamp is too far away from the + // current time, so we'll allow errors for now. + } + } + ControlFlow::Continue(()) + } + + async fn execute_upload_replay(&mut self, report_time: &Time) -> ControlFlow<TestResult> { + if let Some(measurement) = self.state.next_measurement() { + if let Err(error) = upload_replay_report( + measurement, + &self.task, + &self.state.vdaf, + report_time, + &self.components.http_client, + ) + .await + { + warn!(?error, "client error"); + // We expect to receive an error if the report timestamp is too far away from the + // current time, so we'll allow errors for now.
+ } + } + ControlFlow::Continue(()) + } + + async fn execute_garbage_collector(&mut self, role: AggregatorRole) -> ControlFlow<TestResult> { + let garbage_collector = match role { + AggregatorRole::Leader => &self.components.leader_garbage_collector, + AggregatorRole::Helper => &self.components.helper_garbage_collector, + }; + if let Err(error) = garbage_collector.run().await { + error!(?error, "garbage collector error"); + return ControlFlow::Break(TestResult::error(format!("{error:?}"))); + } + ControlFlow::Continue(()) + } + + async fn execute_key_rotator(&mut self, role: AggregatorRole) -> ControlFlow<TestResult> { + let key_rotator = match role { + AggregatorRole::Leader => &self.components.leader_key_rotator, + AggregatorRole::Helper => &self.components.helper_key_rotator, + }; + if let Err(error) = key_rotator.run().await { + error!(?error, "key rotator error"); + return ControlFlow::Break(TestResult::error(format!("{error:?}"))); + } + ControlFlow::Continue(()) + } + + async fn execute_aggregation_job_creator(&mut self) -> ControlFlow<TestResult> { + let aggregation_job_creator = Arc::clone(&self.components.aggregation_job_creator); + let task = Arc::clone(&self.leader_task); + if let Err(error) = aggregation_job_creator + .create_aggregation_jobs_for_task(task) + .await + { + error!(?error, "aggregation job creator error"); + return ControlFlow::Break(TestResult::error(format!("{error:?}"))); + } + ControlFlow::Continue(()) + } + + async fn execute_aggregation_job_driver(&mut self) -> ControlFlow<TestResult> { + let leases = match (self.components.aggregation_job_driver_acquirer_cb)(10).await { + Ok(leases) => leases, + Err(error) => { + error!(?error, "aggregation job driver error"); + return ControlFlow::Break(TestResult::error(format!("{error:?}"))); + } + }; + debug!(count = leases.len(), "acquired aggregation jobs"); + for lease in leases { + if let Err(error) = (self.components.aggregation_job_driver_stepper_cb)(lease).await { + if let aggregator::Error::Http(_) = error { + warn!(?error, "aggregation job driver error"); + return ControlFlow::Continue(()); + } + error!(?error, "aggregation job driver error"); + return ControlFlow::Break(TestResult::error(format!("{error:?}"))); + } + } + ControlFlow::Continue(()) + } + + async fn execute_aggregation_job_driver_request_error(&mut self) -> ControlFlow<TestResult> { + self.components.helper.fault_injector.error_before(); + let result = self.execute_aggregation_job_driver().await; + self.components.helper.fault_injector.reset(); + result + } + + async fn execute_aggregation_job_driver_response_error(&mut self) -> ControlFlow<TestResult> { + self.components.helper.fault_injector.error_after(); + let result = self.execute_aggregation_job_driver().await; + self.components.helper.fault_injector.reset(); + result + } + + async fn execute_collection_job_driver(&mut self) -> ControlFlow<TestResult> { + let leases = match (self.components.collection_job_driver_acquirer_cb)(10).await { + Ok(leases) => leases, + Err(error) => { + error!(?error, "collection job driver error"); + return ControlFlow::Break(TestResult::error(format!("{error:?}"))); + } + }; + debug!(count = leases.len(), "acquired collection jobs"); + for lease in leases { + if let Err(error) = (self.components.collection_job_driver_stepper_cb)(lease).await { + if let aggregator::Error::Http(_) = error { + warn!(?error, "collection job driver error"); + return ControlFlow::Continue(()); + } + error!(?error, "collection job driver error"); + return ControlFlow::Break(TestResult::error(format!("{error:?}"))); + } + } + ControlFlow::Continue(()) + } + + async fn
execute_collection_job_driver_request_error(&mut self) -> ControlFlow<TestResult> { + self.components.helper.fault_injector.error_before(); + let result = self.execute_collection_job_driver().await; + self.components.helper.fault_injector.reset(); + result + } + + async fn execute_collection_job_driver_response_error(&mut self) -> ControlFlow<TestResult> { + self.components.helper.fault_injector.error_after(); + let result = self.execute_collection_job_driver().await; + self.components.helper.fault_injector.reset(); + result + } + + async fn execute_collector_start( + &mut self, + collection_job_id: &CollectionJobId, + query: &Query, + ) -> ControlFlow<TestResult> { + match query { + Query::TimeInterval(interval) => { + let query = janus_messages::Query::new_time_interval(*interval); + match self + .components + .collector + .start_collection_with_id(*collection_job_id, query, &()) + .await + { + Ok(collection_job) => { + self.state + .collection_jobs_time_interval + .insert(*collection_job_id, collection_job); + } + Err(error) => info!(?error, "collector error"), + } + } + Query::FixedSizeCurrentBatch => { + let query = janus_messages::Query::new_fixed_size(FixedSizeQuery::CurrentBatch); + match self + .components + .collector + .start_collection_with_id(*collection_job_id, query, &()) + .await + { + Ok(collection_job) => { + self.state + .collection_jobs_fixed_size + .insert(*collection_job_id, collection_job); + } + Err(error) => info!(?error, "collector error"), + } + } + Query::FixedSizeByBatchId(previous_collection_job_id) => { + if let Some(collection) = self + .state + .aggregate_results_fixed_size + .get(previous_collection_job_id) + { + let query = janus_messages::Query::new_fixed_size(FixedSizeQuery::ByBatchId { + batch_id: *collection.partial_batch_selector().batch_id(), + }); + match self + .components + .collector + .start_collection_with_id(*collection_job_id, query, &()) + .await + { + Ok(collection_job) => { + self.state + .collection_jobs_fixed_size + .insert(*collection_job_id, collection_job); + + // Store a copy of the collection results from the previous collection + // job under this new collection job as well. When we get results from + // polling the "by batch ID" job, we will then compare results from the + // two jobs to ensure they are the same.
+                            self.state.aggregate_results_fixed_size.insert(
+                                *collection_job_id,
+                                Collection::new(
+                                    collection.partial_batch_selector().clone(),
+                                    collection.report_count(),
+                                    *collection.interval(),
+                                    collection.aggregate_result().clone(),
+                                ),
+                            );
+                        }
+                        Err(error) => info!(?error, "collector error"),
+                    }
+                }
+            }
+        }
+        ControlFlow::Continue(())
+    }
+
+    async fn execute_collector_poll(
+        &mut self,
+        collection_job_id: &CollectionJobId,
+    ) -> ControlFlow<TestResult> {
+        if let Some(collection_job) = self
+            .state
+            .collection_jobs_time_interval
+            .get(collection_job_id)
+        {
+            let result = self.components.collector.poll_once(collection_job).await;
+            match result {
+                Ok(PollResult::CollectionResult(collection)) => {
+                    let report_count = collection.report_count();
+                    let interval = *collection.interval();
+                    let aggregate_result = collection.aggregate_result().clone();
+                    let old_opt = self
+                        .state
+                        .aggregate_results_time_interval
+                        .insert(*collection_job_id, collection);
+                    if let Some(old_collection) = old_opt {
+                        if report_count != old_collection.report_count()
+                            || &interval != old_collection.interval()
+                            || &aggregate_result != old_collection.aggregate_result()
+                        {
+                            error!("repeated collection did not match");
+                            return ControlFlow::Break(TestResult::failed());
+                        }
+                    }
+                }
+                Ok(PollResult::NotReady(_)) => {}
+                Err(error) => info!(?error, "collector error"),
+            }
+        } else if let Some(collection_job) =
+            self.state.collection_jobs_fixed_size.get(collection_job_id)
+        {
+            let result = self.components.collector.poll_once(collection_job).await;
+            match result {
+                Ok(PollResult::CollectionResult(collection)) => {
+                    let partial_batch_selector = collection.partial_batch_selector().clone();
+                    let report_count = collection.report_count();
+                    let interval = *collection.interval();
+                    let aggregate_result = collection.aggregate_result().clone();
+                    let old_opt = self
+                        .state
+                        .aggregate_results_fixed_size
+                        .insert(*collection_job_id, collection);
+                    if let Some(old_collection) = old_opt {
+                        if &partial_batch_selector != old_collection.partial_batch_selector()
+                            || report_count != old_collection.report_count()
+                            || &interval != old_collection.interval()
+                            || &aggregate_result != old_collection.aggregate_result()
+                        {
+                            error!("repeated collection did not match");
+                            return ControlFlow::Break(TestResult::failed());
+                        }
+                    }
+                }
+                Ok(PollResult::NotReady(_)) => {}
+                Err(error) => info!(?error, "collector error"),
+            }
+        }
+        ControlFlow::Continue(())
+    }
+}
+
+#[derive(Derivative)]
+#[derivative(Debug)]
+pub(super) struct State {
+    pub(super) stopper: Stopper,
+    pub(super) clock: MockClock,
+    pub(super) meter: Meter,
+    #[derivative(Debug = "ignore")]
+    pub(super) runtime_manager: TestRuntimeManager<&'static str>,
+    pub(super) vdaf_instance: VdafInstance,
+    pub(super) vdaf: Prio3Histogram,
+    pub(super) collection_jobs_time_interval:
+        HashMap<CollectionJobId, CollectionJob<(), TimeInterval>>,
+    pub(super) collection_jobs_fixed_size: HashMap<CollectionJobId, CollectionJob<(), FixedSize>>,
+    pub(super) aggregate_results_time_interval:
+        HashMap<CollectionJobId, Collection<Vec<u128>, TimeInterval>>,
+    pub(super) aggregate_results_fixed_size:
+        HashMap<CollectionJobId, Collection<Vec<u128>, FixedSize>>,
+    pub(super) next_measurement: usize,
+}
+
+impl State {
+    fn new() -> Self {
+        let chunk_length = optimal_chunk_length(MAX_REPORTS);
+        Self {
+            stopper: Stopper::new(),
+            clock: MockClock::new(START_TIME),
+            meter: noop_meter(),
+            runtime_manager: TestRuntimeManager::new(),
+            vdaf_instance: VdafInstance::Prio3Histogram {
+                length: MAX_REPORTS,
+                chunk_length,
+                dp_strategy: vdaf_dp_strategies::Prio3Histogram::NoDifferentialPrivacy,
+            },
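+            // Each measurement is used as a distinct histogram bucket index, so a correct
+            // aggregate contains only zeroes and ones; check_aggregate_results_valid below
+            // relies on this structure.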
+            vdaf: Prio3::new_histogram(2, MAX_REPORTS, chunk_length).unwrap(),
+            collection_jobs_time_interval: HashMap::new(),
+            collection_jobs_fixed_size: HashMap::new(),
+            aggregate_results_time_interval: HashMap::new(),
+            aggregate_results_fixed_size: HashMap::new(),
+            next_measurement: 0,
+        }
+    }
+
+    fn next_measurement(&mut self) -> Option<usize> {
+        if self.next_measurement < MAX_REPORTS {
+            let output = self.next_measurement;
+            self.next_measurement += 1;
+            Some(output)
+        } else {
+            debug!("Too many reports, skipping upload operation");
+            None
+        }
+    }
+}
+
+/// Shard and upload a report, but with a fixed ReportId.
+async fn upload_replay_report(
+    measurement: usize,
+    task: &Task,
+    vdaf: &Prio3Histogram,
+    report_time: &Time,
+    http_client: &reqwest::Client,
+) -> Result<(), janus_client::Error> {
+    // This encodes to "replayreplayreplayrepl".
+    let report_id = ReportId::from([
+        173, 234, 101, 107, 42, 222, 166, 86, 178, 173, 234, 101, 107, 42, 222, 166,
+    ]);
+    let task_id = *task.id();
+    let (public_share, input_shares) = vdaf.shard(&measurement, report_id.as_ref())?;
+    let rounded_time = report_time
+        .to_batch_interval_start(task.time_precision())
+        .unwrap();
+    let report_metadata = ReportMetadata::new(report_id, rounded_time);
+    let encoded_public_share = public_share.get_encoded().unwrap();
+
+    let leader_hpke_config =
+        aggregator_hpke_config(task.leader_aggregator_endpoint(), http_client).await?;
+    let helper_hpke_config =
+        aggregator_hpke_config(task.helper_aggregator_endpoint(), http_client).await?;
+
+    let aad = InputShareAad::new(
+        task_id,
+        report_metadata.clone(),
+        encoded_public_share.clone(),
+    )
+    .get_encoded()?;
+    let leader_encrypted_input_share = hpke::seal(
+        &leader_hpke_config,
+        &HpkeApplicationInfo::new(&Label::InputShare, &Role::Client, &Role::Leader),
+        &PlaintextInputShare::new(Vec::new(), input_shares[0].get_encoded()?).get_encoded()?,
+        &aad,
+    )?;
+    let helper_encrypted_input_share = hpke::seal(
+        &helper_hpke_config,
+        &HpkeApplicationInfo::new(&Label::InputShare, &Role::Client, &Role::Helper),
+        &PlaintextInputShare::new(Vec::new(), input_shares[1].get_encoded()?).get_encoded()?,
+        &aad,
+    )?;
+
+    let report = Report::new(
+        report_metadata,
+        encoded_public_share,
+        leader_encrypted_input_share,
+        helper_encrypted_input_share,
+    );
+
+    let url = task
+        .leader_aggregator_endpoint()
+        .join(&format!("tasks/{task_id}/reports"))
+        .unwrap();
+    retry_http_request(http_request_exponential_backoff(), || async {
+        http_client
+            .put(url.clone())
+            .header(CONTENT_TYPE, Report::MEDIA_TYPE)
+            .body(report.get_encoded().unwrap())
+            .send()
+            .await
+    })
+    .await?;
+
+    Ok(())
+}
+
+async fn aggregator_hpke_config(
+    endpoint: &Url,
+    http_client: &reqwest::Client,
+) -> Result<HpkeConfig, janus_client::Error> {
+    let response = retry_http_request(http_request_exponential_backoff(), || async {
+        http_client
+            .get(endpoint.join("hpke_config").unwrap())
+            .send()
+            .await
+    })
+    .await?;
+    let status = response.status();
+    if !status.is_success() {
+        return Err(janus_client::Error::Http(Box::new(
+            HttpErrorResponse::from(status),
+        )));
+    }
+
+    let list = HpkeConfigList::get_decoded(response.body())?;
+
+    Ok(list.hpke_configs()[0].clone())
+}
+
+fn check_aggregate_results_valid<Q: QueryType>(
+    map: &HashMap<CollectionJobId, Collection<Vec<u128>, Q>>,
+    state: &State,
+) -> bool {
+    for collection in map.values() {
+        let result = collection.aggregate_result();
+        if result.iter().any(|value| *value != 0 && *value != 1) {
+            error!(?result, "bad aggregate result");
+            return false;
+        }
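+        // Measurements are handed out sequentially by State::next_measurement, so buckets at
+        // or past state.next_measurement were never uploaded; a nonzero count there means the
+        // aggregate includes a report that no simulated client sent.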
+        if result[state.next_measurement..]
+            .iter()
+            .any(|value| *value != 0)
+        {
+            error!(
+                ?result,
+                num_measurements = state.next_measurement,
+                "bad aggregate result, unexpected 1 with no corresponding report"
+            );
+            return false;
+        }
+    }
+    true
+}
+
+/// Aggressive exponential backoff parameters for this local-only test. Due to fault injection
+/// operations, we will often be hitting `max_elapsed_time`, so this value needs to be very low.
+pub(super) fn http_request_exponential_backoff() -> ExponentialBackoff {
+    ExponentialBackoff {
+        initial_interval: StdDuration::from_millis(10),
+        max_interval: StdDuration::from_millis(50),
+        multiplier: 2.0,
+        max_elapsed_time: Some(StdDuration::from_millis(250)),
+        ..Default::default()
+    }
+}
diff --git a/integration_tests/tests/integration/simulation/setup.rs b/integration_tests/tests/integration/simulation/setup.rs
new file mode 100644
index 000000000..1e4e3b561
--- /dev/null
+++ b/integration_tests/tests/integration/simulation/setup.rs
@@ -0,0 +1,314 @@
+use std::{
+    net::{Ipv4Addr, SocketAddr},
+    sync::Arc,
+    time::Duration as StdDuration,
+};
+
+use futures::future::BoxFuture;
+use janus_aggregator::{
+    aggregator::{
+        self,
+        aggregation_job_creator::AggregationJobCreator,
+        aggregation_job_driver::AggregationJobDriver,
+        collection_job_driver::{CollectionJobDriver, RetryStrategy},
+        garbage_collector::GarbageCollector,
+        http_handlers::aggregator_handler,
+        key_rotator::KeyRotator,
+        Config as AggregatorConfig,
+    },
+    cache::{
+        GlobalHpkeKeypairCache, TASK_AGGREGATOR_CACHE_DEFAULT_CAPACITY,
+        TASK_AGGREGATOR_CACHE_DEFAULT_TTL,
+    },
+};
+use janus_aggregator_core::{
+    datastore::{
+        self,
+        models::{AcquiredAggregationJob, AcquiredCollectionJob, Lease},
+        test_util::{ephemeral_datastore, EphemeralDatastore},
+        Datastore,
+    },
+    task::{
+        test_util::{Task, TaskBuilder},
+        QueryType,
+    },
+};
+use janus_client::{default_http_client, Client};
+use janus_collector::Collector;
+use janus_core::{test_util::runtime::TestRuntime, time::MockClock, Runtime};
+use prio::vdaf::prio3::Prio3Histogram;
+use tokio::net::TcpListener;
+
+use crate::simulation::{
+    model::Input,
+    proxy::{FaultInjector, FaultInjectorHandler, InspectHandler, InspectMonitor},
+    run::{http_request_exponential_backoff, State},
+};
+
+// Labels for TestRuntimeManager.
+static LEADER_AGGREGATOR_REPORT_WRITER: &str = "leader_aggregator_report_writer";
+static HELPER_AGGREGATOR_REPORT_WRITER: &str = "helper_aggregator_report_writer";
+static LEADER_AGGREGATOR_SERVER: &str = "leader_aggregator_server";
+static HELPER_AGGREGATOR_SERVER: &str = "helper_aggregator_server";
+
+const BATCH_AGGREGATION_SHARD_COUNT: usize = 32;
+const TASK_COUNTER_SHARD_COUNT: u64 = 128;
+
+pub(super) struct SimulationAggregator {
+    pub(super) _ephemeral_datastore: EphemeralDatastore,
+    pub(super) datastore: Arc<Datastore<MockClock>>,
+    pub(super) socket_address: SocketAddr,
+    pub(super) fault_injector: FaultInjector,
+    pub(super) inspect_monitor: InspectMonitor,
+}
+
+impl SimulationAggregator {
+    pub(super) async fn new(
+        report_writer_runtime: TestRuntime,
+        server_runtime: TestRuntime,
+        state: &State,
+    ) -> Self {
+        let ephemeral_datastore = ephemeral_datastore().await;
+        let datastore = Arc::new(ephemeral_datastore.datastore(state.clock.clone()).await);
+
+        let aggregator_config = AggregatorConfig {
+            // Set this to 1 because report uploads will be serialized.
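+            // Combined with a zero write delay, this makes each uploaded report visible in
+            // the datastore immediately instead of waiting for an upload batch to fill.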
+            max_upload_batch_size: 1,
+            max_upload_batch_write_delay: StdDuration::from_secs(0),
+            batch_aggregation_shard_count: 32,
+            task_counter_shard_count: TASK_COUNTER_SHARD_COUNT,
+            global_hpke_configs_refresh_interval: GlobalHpkeKeypairCache::DEFAULT_REFRESH_INTERVAL,
+            hpke_config_signing_key: None,
+            // We only support Taskprov on the helper side, so leave it disabled.
+            taskprov_config: Default::default(),
+            task_cache_ttl: TASK_AGGREGATOR_CACHE_DEFAULT_TTL,
+            task_cache_capacity: TASK_AGGREGATOR_CACHE_DEFAULT_CAPACITY,
+            log_forbidden_mutations: None,
+            require_global_hpke_keys: false,
+        };
+
+        let aggregator_handler = aggregator_handler(
+            Arc::clone(&datastore),
+            state.clock.clone(),
+            report_writer_runtime,
+            &state.meter,
+            aggregator_config,
+        )
+        .await
+        .unwrap();
+
+        let inspect_handler = InspectHandler::new(aggregator_handler);
+        let inspect_monitor = inspect_handler.monitor();
+
+        let fault_injector_handler = FaultInjectorHandler::new(inspect_handler);
+        let fault_injector = fault_injector_handler.controller();
+
+        let server = TcpListener::bind((Ipv4Addr::LOCALHOST, 0)).await.unwrap();
+        let socket_address = server.local_addr().unwrap();
+        let aggregator_future = trillium_tokio::config()
+            .with_stopper(state.stopper.clone())
+            .without_signals()
+            .with_prebound_server(server)
+            .run_async(fault_injector_handler);
+        server_runtime.spawn(aggregator_future);
+
+        SimulationAggregator {
+            _ephemeral_datastore: ephemeral_datastore,
+            datastore,
+            socket_address,
+            fault_injector,
+            inspect_monitor,
+        }
+    }
+}
+
+type JobAcquirerCallback<J> =
+    Box<dyn Fn(usize) -> BoxFuture<'static, Result<Vec<Lease<J>>, datastore::Error>>>;
+type JobStepperCallback<J> =
+    Box<dyn Fn(Lease<J>) -> BoxFuture<'static, Result<(), aggregator::Error>>>;
+
+pub(super) struct Components {
+    pub(super) leader: SimulationAggregator,
+    pub(super) helper: SimulationAggregator,
+    pub(super) http_client: reqwest::Client,
+    pub(super) client: Client<Prio3Histogram>,
+    pub(super) leader_garbage_collector: GarbageCollector<MockClock>,
+    pub(super) helper_garbage_collector: GarbageCollector<MockClock>,
+    pub(super) leader_key_rotator: KeyRotator<MockClock>,
+    pub(super) helper_key_rotator: KeyRotator<MockClock>,
+    pub(super) aggregation_job_creator: Arc<AggregationJobCreator<MockClock>>,
+    pub(super) aggregation_job_driver_acquirer_cb: JobAcquirerCallback<AcquiredAggregationJob>,
+    pub(super) aggregation_job_driver_stepper_cb: JobStepperCallback<AcquiredAggregationJob>,
+    pub(super) collection_job_driver_acquirer_cb: JobAcquirerCallback<AcquiredCollectionJob>,
+    pub(super) collection_job_driver_stepper_cb: JobStepperCallback<AcquiredCollectionJob>,
+    pub(super) collector: Collector<Prio3Histogram>,
+}
+
+impl Components {
+    pub(super) async fn setup(input: &Input, state: &mut State) -> (Self, Task) {
+        let leader = SimulationAggregator::new(
+            state
+                .runtime_manager
+                .with_label(LEADER_AGGREGATOR_REPORT_WRITER),
+            state.runtime_manager.with_label(LEADER_AGGREGATOR_SERVER),
+            state,
+        )
+        .await;
+
+        let helper = SimulationAggregator::new(
+            state
+                .runtime_manager
+                .with_label(HELPER_AGGREGATOR_REPORT_WRITER),
+            state.runtime_manager.with_label(HELPER_AGGREGATOR_SERVER),
+            state,
+        )
+        .await;
+
+        let query_type = if input.is_fixed_size {
+            QueryType::FixedSize {
+                max_batch_size: input.config.max_batch_size,
+                batch_time_window_size: input.config.batch_time_window_size,
+            }
+        } else {
+            QueryType::TimeInterval
+        };
+        let task = TaskBuilder::new(query_type, state.vdaf_instance.clone())
+            .with_leader_aggregator_endpoint(
+                format!("http://{}/", leader.socket_address)
+                    .parse()
+                    .unwrap(),
+            )
+            .with_helper_aggregator_endpoint(
+                format!("http://{}/", helper.socket_address)
+                    .parse()
+                    .unwrap(),
+            )
+            .with_time_precision(input.config.time_precision)
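+            // The remaining task parameters come from the generated simulation input;
+            // report_expiry_age in particular is what gives the garbage collector work to do.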
+            .with_min_batch_size(input.config.min_batch_size)
+            .with_report_expiry_age(input.config.report_expiry_age)
+            .build();
+        let leader_task = task.leader_view().unwrap();
+        let helper_task = task.helper_view().unwrap();
+        leader
+            .datastore
+            .put_aggregator_task(&leader_task)
+            .await
+            .unwrap();
+        helper
+            .datastore
+            .put_aggregator_task(&helper_task)
+            .await
+            .unwrap();
+
+        let http_client = default_http_client().unwrap();
+        let client = Client::builder(
+            *task.id(),
+            task.leader_aggregator_endpoint().clone(),
+            task.helper_aggregator_endpoint().clone(),
+            *task.time_precision(),
+            state.vdaf.clone(),
+        )
+        .with_http_client(http_client.clone())
+        .build()
+        .await
+        .unwrap();
+
+        let leader_garbage_collector = GarbageCollector::new(
+            Arc::clone(&leader.datastore),
+            &state.meter,
+            100,
+            100,
+            100,
+            1,
+            None,
+        );
+
+        let helper_garbage_collector = GarbageCollector::new(
+            Arc::clone(&helper.datastore),
+            &state.meter,
+            100,
+            100,
+            100,
+            1,
+            None,
+        );
+
+        let leader_key_rotator = KeyRotator::new(Arc::clone(&leader.datastore), Default::default());
+
+        let helper_key_rotator = KeyRotator::new(Arc::clone(&helper.datastore), Default::default());
+
+        let aggregation_job_creator = Arc::new(AggregationJobCreator::new(
+            Arc::clone(&leader.datastore),
+            state.meter.clone(),
+            BATCH_AGGREGATION_SHARD_COUNT.try_into().unwrap(),
+            StdDuration::from_secs(0), // unused
+            StdDuration::from_secs(0), // unused
+            input.config.min_aggregation_job_size,
+            input.config.max_aggregation_job_size,
+            5000,
+        ));
+
+        let aggregation_job_driver = Arc::new(AggregationJobDriver::new(
+            reqwest::Client::new(),
+            http_request_exponential_backoff(),
+            &state.meter,
+            BATCH_AGGREGATION_SHARD_COUNT.try_into().unwrap(),
+            TASK_COUNTER_SHARD_COUNT,
+        ));
+        let aggregation_job_driver_acquirer_cb = Box::new(
+            aggregation_job_driver.make_incomplete_job_acquirer_callback(
+                Arc::clone(&leader.datastore),
+                StdDuration::from_secs(600),
+            ),
+        );
+        let aggregation_job_driver_stepper_cb = Box::new(
+            aggregation_job_driver.make_job_stepper_callback(Arc::clone(&leader.datastore), 2),
+        );
+
+        let collection_job_driver = Arc::new(CollectionJobDriver::new(
+            reqwest::Client::new(),
+            http_request_exponential_backoff(),
+            &state.meter,
+            BATCH_AGGREGATION_SHARD_COUNT.try_into().unwrap(),
+            RetryStrategy::new(StdDuration::ZERO, StdDuration::ZERO, 1.0).unwrap(),
+        ));
+        let collection_job_driver_acquirer_cb =
+            Box::new(collection_job_driver.make_incomplete_job_acquirer_callback(
+                Arc::clone(&leader.datastore),
+                StdDuration::from_secs(600),
+            ));
+        let collection_job_driver_stepper_cb = Box::new(
+            collection_job_driver.make_job_stepper_callback(Arc::clone(&leader.datastore), 2),
+        );
+
+        let collector = Collector::builder(
+            *task.id(),
+            task.leader_aggregator_endpoint().clone(),
+            task.collector_auth_token().clone(),
+            task.collector_hpke_keypair().clone(),
+            state.vdaf.clone(),
+        )
+        .build()
+        .unwrap();
+
+        (
+            Self {
+                leader,
+                helper,
+                http_client,
+                client,
+                leader_garbage_collector,
+                helper_garbage_collector,
+                leader_key_rotator,
+                helper_key_rotator,
+                aggregation_job_creator,
+                aggregation_job_driver_acquirer_cb,
+                aggregation_job_driver_stepper_cb,
+                collection_job_driver_acquirer_cb,
+                collection_job_driver_stepper_cb,
+                collector,
+            },
+            task,
+        )
+    }
+}