Skip to content

Commit

Permalink
key rotator: allow running it as part of aggregator process (#3228)
Browse files Browse the repository at this point in the history
* key rotator: allow running it as part of aggregator process

* Minor amount of logging

* Comments are good

* PR feedback
  • Loading branch information
inahga authored Jun 20, 2024
1 parent fe1749f commit 901f8c5
Show file tree
Hide file tree
Showing 4 changed files with 146 additions and 26 deletions.
105 changes: 84 additions & 21 deletions aggregator/src/binaries/aggregator.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
use crate::{
aggregator::{self, http_handlers::aggregator_handler},
aggregator::{
self,
http_handlers::aggregator_handler,
key_rotator::{deserialize_hpke_key_rotator_config, HpkeKeyRotatorConfig, KeyRotator},
},
binaries::garbage_collector::run_garbage_collector,
binary_utils::{setup_server, BinaryContext, BinaryOptions, CommonBinaryOptions},
cache::{
Expand All @@ -24,11 +28,10 @@ use serde::{de, Deserialize, Deserializer, Serialize};
use std::{
future::{ready, Future},
path::PathBuf,
pin::Pin,
};
use std::{iter::Iterator, net::SocketAddr, sync::Arc, time::Duration};
use tokio::{join, sync::watch};
use tracing::info;
use tokio::{spawn, sync::watch, time::interval, try_join};
use tracing::{error, info};
use trillium::Handler;
use trillium_router::router;
use url::Url;
Expand Down Expand Up @@ -78,19 +81,41 @@ async fn run_aggregator(
None,
);

let garbage_collector_future = {
let garbage_collector_handle = {
let datastore = Arc::clone(&datastore);
let gc_config = config.garbage_collection.take();
let meter = meter.clone();
let stopper = stopper.clone();
async move {
spawn(async move {
if let Some(gc_config) = gc_config {
info!("Running garbage collector");
run_garbage_collector(datastore, gc_config, meter, stopper).await;
}
}
})
};

let key_rotator_handle = {
let datastore = Arc::clone(&datastore);
let config = config.key_rotator.take();
let stopper = stopper.clone();
spawn(async move {
if let Some(config) = config {
info!("Running key rotator");
let key_rotator = KeyRotator::new(datastore, config.hpke);
let mut interval = interval(Duration::from_secs(config.frequency_s));
// Note that `interval` fires immediately at first, so the key rotator runs
// immediately on boot. This takes care of bootstrapping keys on the first run of
// Janus.
while stopper.stop_future(interval.tick()).await.is_some() {
if let Err(err) = key_rotator.run().await {
error!(?err, "key rotator error");
}
}
}
})
};

let aggregator_api_future: Pin<Box<dyn Future<Output = ()> + Send + 'static>> =
let aggregator_api_handle =
match build_aggregator_api_handler(&options, &config, &datastore, &meter)? {
Some((handler, config)) => {
if let Some(listen_address) = config.listen_address {
Expand All @@ -103,7 +128,7 @@ async fn run_aggregator(

info!(?aggregator_api_bound_address, "Running aggregator API");

Box::pin(aggregator_api_server)
spawn(aggregator_api_server)
} else if let Some(path_prefix) = &config.path_prefix {
// Create a Trillium handler under the requested path prefix, which we'll add to
// the DAP API handler in the setup_server call below
Expand All @@ -115,27 +140,29 @@ async fn run_aggregator(
// Append wildcard so that this handler will match anything under the prefix
let path_prefix = format!("{path_prefix}/*");
handlers.1 = Some(router().all(path_prefix, handler));
Box::pin(ready(()))
spawn(ready(()))
} else {
unreachable!("the configuration should not have deserialized to this state")
}
}
None => Box::pin(ready(())),
None => spawn(ready(())),
};

let (aggregator_bound_address, aggregator_server) =
setup_server(config.listen_address, stopper.clone(), handlers)
.await
.context("failed to create aggregator server")?;
sender.send_replace(Some(aggregator_bound_address));
let aggregator_server_handle = spawn(aggregator_server);

info!(?aggregator_bound_address, "Running aggregator");

join!(
aggregator_server,
garbage_collector_future,
aggregator_api_future
);
try_join!(
aggregator_server_handle,
garbage_collector_handle,
key_rotator_handle,
aggregator_api_handle
)?;
Ok(())
}

Expand Down Expand Up @@ -318,6 +345,10 @@ pub struct Config {
#[serde(default)]
pub garbage_collection: Option<GarbageCollectorConfig>,

/// Run the key rotator in this binary.
#[serde(default)]
pub key_rotator: Option<KeyRotatorConfig>,

/// Address on which this server should listen for connections to the DAP aggregator API and
/// serve its API endpoints.
pub listen_address: SocketAddr,
Expand Down Expand Up @@ -365,6 +396,15 @@ pub struct Config {
pub task_cache_capacity: Option<u64>,
}

/// Configuration for running the key rotator inside the aggregator process.
///
/// Appears in [`Config`] as the optional `key_rotator` section. When that section is absent, the
/// key rotator is not run as part of this binary.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct KeyRotatorConfig {
/// How frequently the key rotator is run, in seconds.
///
/// NOTE(review): this value is handed to `tokio::time::interval`, which panics when given a
/// zero period, so `0` is presumably not a valid setting — confirm whether it should be
/// rejected when the configuration is deserialized.
pub frequency_s: u64,

/// Rotation policy for HPKE keys; passed as-is to `KeyRotator::new`.
#[serde(deserialize_with = "deserialize_hpke_key_rotator_config")]
pub hpke: HpkeKeyRotatorConfig,
}

fn default_task_counter_shard_count() -> u64 {
32
}
Expand Down Expand Up @@ -460,10 +500,11 @@ pub(crate) fn parse_pem_ec_private_key(ec_private_key_pem: &str) -> Result<Ecdsa

#[cfg(test)]
mod tests {
use super::{AggregatorApi, Config, GarbageCollectorConfig, Options};
use super::{AggregatorApi, Config, GarbageCollectorConfig, KeyRotatorConfig, Options};
use crate::{
aggregator::{
self,
key_rotator::HpkeKeyRotatorConfig,
test_util::{hpke_config_signing_key, HPKE_CONFIG_SIGNING_KEY_PEM},
},
config::{
Expand All @@ -478,16 +519,18 @@ mod tests {
};
use assert_matches::assert_matches;
use clap::CommandFactory;
use janus_core::test_util::roundtrip_encoding;
use janus_core::{hpke::HpkeCiphersuite, test_util::roundtrip_encoding};
use janus_messages::{Duration, HpkeAeadId, HpkeKdfId, HpkeKemId};
use rand::random;
use ring::{
rand::SystemRandom,
signature::{KeyPair, UnparsedPublicKey, ECDSA_P256_SHA256_ASN1},
};
use std::{
collections::HashSet,
net::{IpAddr, Ipv4Addr, SocketAddr},
path::PathBuf,
time::Duration,
time::Duration as StdDuration,
};

#[test]
Expand Down Expand Up @@ -518,6 +561,26 @@ mod tests {
tasks_per_tx: 15,
concurrent_tx_limit: Some(23),
}),
key_rotator: Some(KeyRotatorConfig {
frequency_s: random(),
hpke: HpkeKeyRotatorConfig {
pending_duration: Duration::from_seconds(random()),
active_duration: Duration::from_seconds(random()),
expired_duration: Duration::from_seconds(random()),
ciphersuites: HashSet::from([
HpkeCiphersuite::new(
HpkeKemId::P256HkdfSha256,
HpkeKdfId::HkdfSha256,
HpkeAeadId::Aes128Gcm,
),
HpkeCiphersuite::new(
HpkeKemId::P521HkdfSha512,
HpkeKdfId::HkdfSha512,
HpkeAeadId::Aes256Gcm,
),
]),
},
}),
aggregator_api: Some(aggregator_api),
common_config: CommonConfig {
database: generate_db_config(),
Expand Down Expand Up @@ -674,7 +737,7 @@ mod tests {
.unwrap(),
&aggregator::Config {
max_upload_batch_size: 100,
max_upload_batch_write_delay: Duration::from_millis(250),
max_upload_batch_write_delay: StdDuration::from_millis(250),
batch_aggregation_shard_count: 32,
taskprov_config: TaskprovConfig::default(),
hpke_config_signing_key: Some(hpke_config_signing_key()),
Expand Down Expand Up @@ -835,7 +898,7 @@ mod tests {
.unwrap(),
&aggregator::Config {
max_upload_batch_size: 100,
max_upload_batch_write_delay: Duration::from_millis(250),
max_upload_batch_write_delay: StdDuration::from_millis(250),
batch_aggregation_shard_count: 32,
taskprov_config: TaskprovConfig::default(),
..Default::default()
Expand Down
38 changes: 33 additions & 5 deletions aggregator/tests/integration/graceful_shutdown.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,13 @@
use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine};
use janus_aggregator::{
aggregator::key_rotator::HpkeKeyRotatorConfig,
binaries::{
aggregation_job_creator::Config as AggregationJobCreatorConfig,
aggregation_job_driver::Config as AggregationJobDriverConfig,
aggregator::{AggregatorApi, Config as AggregatorConfig, GarbageCollectorConfig},
aggregator::{
AggregatorApi, Config as AggregatorConfig, GarbageCollectorConfig, KeyRotatorConfig,
},
collection_job_driver::Config as CollectionJobDriverConfig,
garbage_collector::Config as GarbageCollectorBinaryConfig,
},
Expand All @@ -23,15 +26,20 @@ use janus_aggregator_core::{
datastore::test_util::ephemeral_datastore,
task::{test_util::TaskBuilder, QueryType},
};
use janus_core::{test_util::install_test_trace_subscriber, time::RealClock, vdaf::VdafInstance};
use janus_core::{
hpke::HpkeCiphersuite, test_util::install_test_trace_subscriber, time::RealClock,
vdaf::VdafInstance,
};
use janus_messages::{Duration, HpkeAeadId, HpkeKdfId, HpkeKemId};
use reqwest::Url;
use serde::Serialize;
use std::{
collections::HashSet,
future::Future,
io::{ErrorKind, Write},
net::{Ipv4Addr, SocketAddr},
process::{Child, Command, Stdio},
time::Instant,
time::{Duration as StdDuration, Instant},
};
use tokio::{
io::{AsyncBufReadExt, BufReader},
Expand Down Expand Up @@ -63,7 +71,7 @@ async fn wait_for_server(addr: SocketAddr) -> Result<(), Timeout> {
for _ in 0..30 {
match TcpStream::connect(addr).await {
Ok(_) => return Ok(()),
Err(_) => sleep(std::time::Duration::from_millis(500)).await,
Err(_) => sleep(StdDuration::from_millis(500)).await,
}
}
Err(Timeout)
Expand Down Expand Up @@ -215,7 +223,7 @@ async fn graceful_shutdown<C: BinaryConfig + Serialize>(binary_name: &str, mut c
// Confirm that the binary under test shuts down promptly.
let start = Instant::now();
let (mut child, child_exit_status_res) = spawn_blocking(move || {
let result = child.wait_timeout(std::time::Duration::from_secs(15));
let result = child.wait_timeout(StdDuration::from_secs(15));
(child, result)
})
.await
Expand Down Expand Up @@ -273,6 +281,26 @@ async fn aggregator_shutdown() {
tasks_per_tx: 1,
concurrent_tx_limit: None,
}),
key_rotator: Some(KeyRotatorConfig {
frequency_s: 60 * 60 * 6,
hpke: HpkeKeyRotatorConfig {
pending_duration: Duration::from_seconds(60),
active_duration: Duration::from_seconds(60 * 60 * 24),
expired_duration: Duration::from_seconds(60 * 60 * 24),
ciphersuites: HashSet::from([
HpkeCiphersuite::new(
HpkeKemId::P256HkdfSha256,
HpkeKdfId::HkdfSha256,
HpkeAeadId::Aes128Gcm,
),
HpkeCiphersuite::new(
HpkeKemId::P521HkdfSha512,
HpkeKdfId::HkdfSha512,
HpkeAeadId::Aes256Gcm,
),
]),
},
}),
listen_address: aggregator_listen_address,
aggregator_api: Some(AggregatorApi {
listen_address: Some(aggregator_api_listen_address),
Expand Down
28 changes: 28 additions & 0 deletions docs/samples/advanced_config/aggregator.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -162,3 +162,31 @@ garbage_collection:
# The maximum number of collection jobs (& related artifacts), per task, to delete in a single run
# of the garbage collector.
collection_limit: 50

# Configuration for the key rotator. Allows running the key rotator as part of the aggregator process.
# If omitted, you should run the key rotator as a separate cronjob.
key_rotator:
# How frequently to run the key rotator, in seconds. Required.
frequency_s: 3600

# Rotation policy for global HPKE keys.
hpke:
# How long keys remain pending before they're promoted to active. Should
# be greater than the global HPKE keypair cache refresh rate. Defaults to
# 1 hour.
pending_duration_s: 3600

# The TTL of keys. Defaults to 4 weeks.
active_duration_s: 2419200

# How long keys can be expired before being deleted. Should be greater than
# how long clients cache HPKE keys. Defaults to 1 week.
expired_duration_s: 604800

# The set of keys to manage, identified by ciphersuite. At least one is
# required. Each entry represents a key with a particular ciphersuite.
ciphersuite:
# Defaults to a key with these algorithms.
- kem_id: P521HkdfSha512
kdf_id: HkdfSha512
aead_id: Aes256Gcm
1 change: 1 addition & 0 deletions integration_tests/src/janus.rs
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,7 @@ impl JanusInProcess {
common_config: common_config.clone(),
taskprov_config: TaskprovConfig::default(),
garbage_collection: None,
key_rotator: None,
listen_address: (Ipv4Addr::LOCALHOST, 0).into(),
aggregator_api: None,
max_upload_batch_size: 100,
Expand Down

0 comments on commit 901f8c5

Please sign in to comment.