diff --git a/deploy-tool/main.go b/deploy-tool/main.go index f78a45633..5c0a0e3da 100644 --- a/deploy-tool/main.go +++ b/deploy-tool/main.go @@ -61,6 +61,10 @@ type SpecificManifest struct { // IngestionBucket is the region+name of the bucket that the data share // processor which owns the manifest reads ingestion batches from. IngestionBucket string `json:"ingestion-bucket"` + // IngestionIdentity is the ARN of the AWS IAM role that should be assumed + // by an ingestion server to write to this data share processor's ingestion + // bucket, if the ingestor does not have an AWS account of their own. + IngestionIdentity string `json:"ingestion-identity"` // PeerValidationBucket is the region+name of the bucket that the data share // processor which owns the manifest reads peer validation batches from. PeerValidationBucket string `json:"peer-validation-bucket"` diff --git a/facilitator/src/aggregation.rs b/facilitator/src/aggregation.rs index 18ea4b5fd..9b58181ba 100644 --- a/facilitator/src/aggregation.rs +++ b/facilitator/src/aggregation.rs @@ -23,6 +23,7 @@ pub struct BatchAggregator<'a> { ingestion_transport: &'a mut VerifiableAndDecryptableTransport, aggregation_batch: BatchWriter<'a, SumPart, InvalidPacket>, share_processor_signing_key: &'a BatchSigningKey, + total_individual_clients: i64, } impl<'a> BatchAggregator<'a> { @@ -57,6 +58,7 @@ impl<'a> BatchAggregator<'a> { &mut *aggregation_transport.transport, ), share_processor_signing_key: &aggregation_transport.batch_signing_key, + total_individual_clients: 0, }) } @@ -115,8 +117,6 @@ impl<'a> BatchAggregator<'a> { .map(|f| u32::from(*f) as i64) .collect(); - let total_individual_clients = accumulator_server.total_shares().len() as i64; - let sum_signature = self.aggregation_batch.put_header( &SumPart { batch_uuids: batch_ids.iter().map(|pair| pair.0).collect(), @@ -130,7 +130,7 @@ impl<'a> BatchAggregator<'a> { aggregation_start_time: self.aggregation_start.timestamp_millis(), aggregation_end_time: self.aggregation_end.timestamp_millis(), packet_file_digest: invalid_packets_digest.as_ref().to_vec(), - total_individual_clients, + total_individual_clients: self.total_individual_clients, }, &self.share_processor_signing_key.key, )?; @@ -282,6 +282,7 @@ impl<'a> BatchAggregator<'a> { if !valid { invalid_uuids.push(peer_validation_packet.uuid); } + self.total_individual_clients += 1; did_aggregate_shares = true; break; } diff --git a/facilitator/src/batch.rs b/facilitator/src/batch.rs index 2a4272eb9..18ae50995 100644 --- a/facilitator/src/batch.rs +++ b/facilitator/src/batch.rs @@ -127,6 +127,10 @@ impl<'a, H: Header, P: Packet> BatchReader<'a, H, P> { } } + pub fn path(&self) -> String { + self.transport.path() + } + /// Return the parsed header from this batch, but only if its signature is /// valid. The signature is checked by getting the key_identifier value from /// the signature message, using that to obtain a public key from the @@ -222,6 +226,10 @@ impl<'a, H: Header, P: Packet> BatchWriter<'a, H, P> { } } + pub fn path(&self) -> String { + self.transport.path() + } + /// Encode the provided header into Avro, sign that representation with the /// provided key and write the header into the batch. Returns the signature /// on success. diff --git a/facilitator/src/intake.rs b/facilitator/src/intake.rs index 03315b6eb..33c7e7aac 100644 --- a/facilitator/src/intake.rs +++ b/facilitator/src/intake.rs @@ -4,8 +4,9 @@ use crate::{ transport::{SignableTransport, VerifiableAndDecryptableTransport}, BatchSigningKey, Error, }; -use anyhow::{anyhow, Context, Result}; +use anyhow::{anyhow, ensure, Context, Result}; use chrono::NaiveDateTime; +use log::info; use prio::{encrypt::PrivateKey, finite_field::Field, server::Server}; use ring::signature::UnparsedPublicKey; use std::{collections::HashMap, convert::TryFrom, iter::Iterator}; @@ -15,8 +16,8 @@ use uuid::Uuid; /// sent by the ingestion server and emitting validation shares to the other /// share processor. pub struct BatchIntaker<'a> { - ingestion_batch: BatchReader<'a, IngestionHeader, IngestionDataSharePacket>, - ingestor_public_keys: &'a HashMap>>, + intake_batch: BatchReader<'a, IngestionHeader, IngestionDataSharePacket>, + intake_public_keys: &'a HashMap>>, packet_decryption_keys: &'a Vec, peer_validation_batch: BatchWriter<'a, ValidationHeader, ValidationPacket>, peer_validation_batch_signing_key: &'a BatchSigningKey, @@ -36,11 +37,11 @@ impl<'a> BatchIntaker<'a> { is_first: bool, ) -> Result> { Ok(BatchIntaker { - ingestion_batch: BatchReader::new( + intake_batch: BatchReader::new( Batch::new_ingestion(aggregation_name, batch_id, date), &mut *ingestion_transport.transport.transport, ), - ingestor_public_keys: &ingestion_transport.transport.batch_signing_public_keys, + intake_public_keys: &ingestion_transport.transport.batch_signing_public_keys, packet_decryption_keys: &ingestion_transport.packet_decryption_keys, peer_validation_batch: BatchWriter::new( Batch::new_validation(aggregation_name, batch_id, date, is_first), @@ -60,13 +61,18 @@ impl<'a> BatchIntaker<'a> { /// and packet file, then computes validation shares and sends them to the /// peer share processor. pub fn generate_validation_share(&mut self) -> Result<()> { - let ingestion_header = self.ingestion_batch.header(self.ingestor_public_keys)?; - if ingestion_header.bins <= 0 { - return Err(anyhow!( - "invalid bins/dimension value {}", - ingestion_header.bins - )); - } + info!( + "processing intake from {} and saving validity to {} and {}", + self.intake_batch.path(), + self.own_validation_batch.path(), + self.peer_validation_batch.path() + ); + let ingestion_header = self.intake_batch.header(self.intake_public_keys)?; + ensure!( + ingestion_header.bins > 0, + "invalid bin count {}", + ingestion_header.bins + ); // Ideally, we would use the encryption_key_id in the ingestion packet // to figure out which private key to use for decryption, but that field @@ -82,7 +88,7 @@ impl<'a> BatchIntaker<'a> { // Read all the ingestion packets, generate a verification message for // each, and write them to the validation batch. let mut ingestion_packet_reader = - self.ingestion_batch.packet_file_reader(&ingestion_header)?; + self.intake_batch.packet_file_reader(&ingestion_header)?; let packet_file_digest = self.peer_validation_batch.multi_packet_file_writer( vec![&mut self.own_validation_batch], diff --git a/facilitator/src/manifest.rs b/facilitator/src/manifest.rs index 13bb6f476..234559b63 100644 --- a/facilitator/src/manifest.rs +++ b/facilitator/src/manifest.rs @@ -96,6 +96,10 @@ pub struct SpecificManifest { /// Region and name of the ingestion S3 bucket owned by this data share /// processor. ingestion_bucket: String, + // The ARN of the AWS IAM role that should be assumed by an ingestion server + // to write to this data share processor's ingestion bucket, if the ingestor + // does not have an AWS account of their own. + ingestion_identity: String, /// Region and name of the peer validation S3 bucket owned by this data /// share processor. peer_validation_bucket: String, @@ -171,8 +175,13 @@ struct IngestionServerIdentity { aws_iam_entity: Option, /// The numeric identifier of the GCP service account that this ingestion /// server uses to authenticate via OIDC identity federation to access - /// ingestion buckets. - google_service_account: Option, + /// ingestion buckets. While this field's value is a number, facilitator + /// treats it as an opaque string. + google_service_account: Option, + /// The email address of the GCP service account that this ingestion server + /// uses to authenticate via OIDC identity federation to access ingestion + /// buckets. + gcp_service_account_email: Option, } /// Represents an ingestion server's global manifest. @@ -480,6 +489,7 @@ mod tests { }} }}, "ingestion-bucket": "us-west-1/ingestion", + "ingestion-identity": "arn:aws:iam:something:fake", "peer-validation-bucket": "us-west-1/validation" }} "#, @@ -509,8 +519,9 @@ mod tests { format: 0, batch_signing_public_keys: expected_batch_keys, packet_encryption_certificates: expected_packet_encryption_certificates, - ingestion_bucket: "us-west-1/ingestion".to_string(), - peer_validation_bucket: "us-west-1/validation".to_string(), + ingestion_bucket: "us-west-1/ingestion".to_owned(), + ingestion_identity: "arn:aws:iam:something:fake".to_owned(), + peer_validation_bucket: "us-west-1/validation".to_owned(), }; assert_eq!(manifest, expected_manifest); let batch_signing_keys = manifest.batch_signing_public_keys().unwrap(); @@ -559,6 +570,7 @@ mod tests { } }, "ingestion-bucket": "us-west-1/ingestion", + "ingestion-identity": "arn:aws:iam:something:fake", "peer-validation-bucket": "us-west-1/validation" } "#, @@ -578,6 +590,7 @@ mod tests { } }, "ingestion-bucket": "us-west-1/ingestion", + "ingestion-identity": "arn:aws:iam:something:fake", "peer-validation-bucket": "us-west-1/validation" } "#, @@ -597,9 +610,30 @@ mod tests { } }, "ingestion-bucket": "us-west-1/ingestion", + "ingestion-identity": "arn:aws:iam:something:fake", "peer-validation-bucket": "us-west-1/validation" } "#, + // Role ARN with wrong type + r#" +{ + "format": 0, + "packet-encryption-certificates": { + "fake-key-1": { + "certificate": "who cares" + } + }, + "batch-signing-public-keys": { + "fake-key-2": { + "expiration": "", + "public-key": "-----BEGIN PUBLIC KEY-----\nfoo\n-----END PUBLIC KEY-----" + } + }, + "ingestion-bucket": "us-west-1/ingestion", + "ingestion-identity": 1, + "peer-validation-bucket": "us-west-1/validation" +} +"#, ]; for invalid_manifest in &invalid_manifests { @@ -627,6 +661,7 @@ mod tests { } }, "ingestion-bucket": "us-west-1/ingestion", + "ingestion-identity": "arn:aws:iam:something:fake", "peer-validation-bucket": "us-west-1/validation" } "#, @@ -646,6 +681,7 @@ mod tests { } }, "ingestion-bucket": "us-west-1/ingestion", + "ingestion-identity": "arn:aws:iam:something:fake", "peer-validation-bucket": "us-west-1/validation" } "#, @@ -665,6 +701,7 @@ mod tests { } }, "ingestion-bucket": "us-west-1/ingestion", + "ingestion-identity": "arn:aws:iam:something:fake", "peer-validation-bucket": "us-west-1/validation" } "#, @@ -696,7 +733,8 @@ mod tests { { "format": 0, "server-identity": { - "google-service-account": 123456789012345 + "google-service-account": "112310747466759665351", + "gcp-service-account-email": "foo@bar.com" }, "batch-signing-public-keys": { "key-identifier-2": { @@ -729,7 +767,11 @@ mod tests { assert_eq!(manifest.server_identity.aws_iam_entity, None); assert_eq!( manifest.server_identity.google_service_account, - Some(123456789012345) + Some("112310747466759665351".to_owned()) + ); + assert_eq!( + manifest.server_identity.gcp_service_account_email, + Some("foo@bar.com".to_owned()) ); let batch_signing_public_keys = manifest.batch_signing_public_keys().unwrap(); batch_signing_public_keys.get("key-identifier-2").unwrap(); diff --git a/facilitator/src/transport.rs b/facilitator/src/transport.rs index e4240ed4d..7a81db218 100644 --- a/facilitator/src/transport.rs +++ b/facilitator/src/transport.rs @@ -64,4 +64,6 @@ pub trait Transport { /// Returns an std::io::Write instance into which the contents of the value /// may be written. fn put(&mut self, key: &str) -> Result>; + + fn path(&self) -> String; } diff --git a/facilitator/src/transport/gcs.rs b/facilitator/src/transport/gcs.rs index 97417014a..3667f7b72 100644 --- a/facilitator/src/transport/gcs.rs +++ b/facilitator/src/transport/gcs.rs @@ -8,7 +8,7 @@ use chrono::{prelude::Utc, DateTime, Duration}; use log::info; use serde::Deserialize; use std::{ - io, + fmt, io, io::{Read, Write}, }; @@ -17,7 +17,6 @@ const DEFAULT_OAUTH_TOKEN_URL: &str = "http://metadata.google.internal:80/computeMetadata/v1/instance/service-accounts/default/token"; /// A wrapper around an Oauth token and its expiration date. -#[derive(Debug)] struct OauthToken { token: String, expiration: DateTime, @@ -33,7 +32,7 @@ impl OauthToken { /// Represents the response from a GET request to the GKE metadata service's /// service account token endpoint. Structure is derived from empirical /// observation of the JSON scraped from inside a GKE job. -#[derive(Debug, Deserialize, PartialEq)] +#[derive(Deserialize, PartialEq)] struct MetadataServiceTokenResponse { access_token: String, expires_in: i64, @@ -43,7 +42,7 @@ struct MetadataServiceTokenResponse { /// Represents the response from a POST request to the GCP IAM service's /// generateAccessToken endpoint. /// https://cloud.google.com/iam/docs/reference/credentials/rest/v1/projects.serviceAccounts/generateAccessToken -#[derive(Debug, Deserialize, PartialEq)] +#[derive(Deserialize, PartialEq)] #[serde(rename_all = "camelCase")] struct GenerateAccessTokenResponse { access_token: String, @@ -53,30 +52,44 @@ struct GenerateAccessTokenResponse { /// OauthTokenProvider manages a default service account Oauth token (i.e. the /// one for a GCP service account mapped to a Kubernetes service account) and an /// Oauth token used to impersonate another service account. -#[derive(Debug)] struct OauthTokenProvider { /// Holds the service account email to impersonate, if one was provided to /// OauthTokenProvider::new. - service_account_to_impersonate: Option, + account_to_impersonate: Option, /// This field is None after instantiation and is Some after the first /// successful request for a token for the default service account, though /// the contained token may be expired. - default_service_account_oauth_token: Option, + default_account_token: Option, /// This field is None after instantiation and is Some after the first /// successful request for a token for the impersonated service account, /// though the contained token may be expired. This will always be None if - /// service_account_to_impersonate is None. - impersonated_service_account_oauth_token: Option, + /// account_to_impersonate is None. + impersonated_account_token: Option, } +impl fmt::Debug for OauthTokenProvider { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("OauthTokenProvider") + .field("account_to_impersonate", &self.account_to_impersonate) + .field( + "default_account_token", + &self.default_account_token.as_ref().map(|_| "redacted"), + ) + .field( + "impersonated_account_token", + &self.default_account_token.as_ref().map(|_| "redacted"), + ) + .finish() + } +} impl OauthTokenProvider { /// Creates a token provider which can impersonate the specified service /// account. - fn new(service_account_to_impersonate: Option) -> OauthTokenProvider { + fn new(account_to_impersonate: Option) -> OauthTokenProvider { OauthTokenProvider { - service_account_to_impersonate, - default_service_account_oauth_token: None, - impersonated_service_account_oauth_token: None, + account_to_impersonate: account_to_impersonate, + default_account_token: None, + impersonated_account_token: None, } } @@ -87,9 +100,9 @@ impl OauthTokenProvider { /// impersonation is taking place, provides the default service account /// Oauth token. fn ensure_storage_access_oauth_token(&mut self) -> Result { - match self.service_account_to_impersonate { + match self.account_to_impersonate { Some(_) => self.ensure_impersonated_service_account_oauth_token(), - None => self.ensure_default_service_account_oauth_token(), + None => self.ensure_default_account_token(), } } @@ -97,8 +110,8 @@ impl OauthTokenProvider { /// is valid. Otherwise obtains and returns a new one. /// The returned value is an owned reference because the token owned by this /// struct could change while the caller is still holding the returned token - fn ensure_default_service_account_oauth_token(&mut self) -> Result { - if let Some(token) = &self.default_service_account_oauth_token { + fn ensure_default_account_token(&mut self) -> Result { + if let Some(token) = &self.default_account_token { if !token.expired() { return Ok(token.token.clone()); } @@ -125,7 +138,7 @@ impl OauthTokenProvider { return Err(anyhow!("unexpected token type {}", response.token_type)); } - self.default_service_account_oauth_token = Some(OauthToken { + self.default_account_token = Some(OauthToken { token: response.access_token.clone(), expiration: Utc::now() + Duration::seconds(response.expires_in), }); @@ -136,25 +149,22 @@ impl OauthTokenProvider { /// Returns the current OAuth token for the impersonated service account, if /// it is valid. Otherwise obtains and returns a new one. fn ensure_impersonated_service_account_oauth_token(&mut self) -> Result { - if self.service_account_to_impersonate.is_none() { + if self.account_to_impersonate.is_none() { return Err(anyhow!("no service account to impersonate was provided")); } - if let Some(token) = &self.impersonated_service_account_oauth_token { + if let Some(token) = &self.impersonated_account_token { if !token.expired() { return Ok(token.token.clone()); } } - let service_account_to_impersonate = self.service_account_to_impersonate.clone().unwrap(); + let service_account_to_impersonate = self.account_to_impersonate.clone().unwrap(); // API reference: // https://cloud.google.com/iam/docs/reference/credentials/rest/v1/projects.serviceAccounts/generateAccessToken let request_url = format!("https://iamcredentials.googleapis.com/v1/projects/-/serviceAccounts/{}:generateAccessToken", service_account_to_impersonate); - let auth = format!( - "Bearer {}", - self.ensure_default_service_account_oauth_token()? - ); + let auth = format!("Bearer {}", self.ensure_default_account_token()?); let http_response = ureq::post(&request_url) .set("Authorization", &auth) .set("Content-Type", "application/json") @@ -181,7 +191,7 @@ impl OauthTokenProvider { let response = http_response .into_json_deserialize::() .context("failed to deserialize response from IAM API")?; - self.impersonated_service_account_oauth_token = Some(OauthToken { + self.impersonated_account_token = Some(OauthToken { token: response.access_token.clone(), expiration: response.expire_time, }); @@ -215,8 +225,15 @@ impl GCSTransport { } impl Transport for GCSTransport { + fn path(&self) -> String { + self.path.to_string() + } + fn get(&mut self, key: &str) -> Result> { - info!("get {} as {:?}", self.path, self.oauth_token_provider); + info!( + "get {}/{} as {:?}", + self.path, key, self.oauth_token_provider + ); // Per API reference, the object key must be URL encoded. // API reference: https://cloud.google.com/storage/docs/json_api/v1/objects/get let encoded_key = urlencoding::encode(&[&self.path.key, key].concat()); @@ -252,7 +269,10 @@ impl Transport for GCSTransport { } fn put(&mut self, key: &str) -> Result> { - info!("get {} as {:?}", self.path, self.oauth_token_provider); + info!( + "put {}/{} as {:?}", + self.path, key, self.oauth_token_provider + ); // The Oauth token will only be used once, during the call to // StreamingTransferWriter::new, so we don't have to worry about it // expiring during the lifetime of that object, and so obtain a token diff --git a/facilitator/src/transport/local.rs b/facilitator/src/transport/local.rs index 92911b62c..8e66450cc 100644 --- a/facilitator/src/transport/local.rs +++ b/facilitator/src/transport/local.rs @@ -29,6 +29,10 @@ impl LocalFileTransport { } impl Transport for LocalFileTransport { + fn path(&self) -> String { + self.directory.to_string_lossy().to_string() + } + fn get(&mut self, key: &str) -> Result> { let path = self.directory.join(LocalFileTransport::relative_path(key)); let f = diff --git a/facilitator/src/transport/s3.rs b/facilitator/src/transport/s3.rs index dcffdc8f9..fe17d1ffe 100644 --- a/facilitator/src/transport/s3.rs +++ b/facilitator/src/transport/s3.rs @@ -182,8 +182,12 @@ impl S3Transport { type ClientProvider = Box) -> Result>; impl Transport for S3Transport { + fn path(&self) -> String { + self.path.to_string() + } + fn get(&mut self, key: &str) -> Result> { - info!("get {} as {:?}", self.path, self.iam_role); + info!("get {}/{} as {:?}", self.path, key, self.iam_role); let mut runtime = basic_runtime()?; let client = (self.client_provider)(&self.path.region, self.iam_role.clone())?; let get_output = runtime @@ -200,7 +204,7 @@ impl Transport for S3Transport { } fn put(&mut self, key: &str) -> Result> { - info!("put {} as {:?}", self.path, self.iam_role); + info!("put {}/{} as {:?}", self.path, key, self.iam_role); let writer = MultipartUploadWriter::new( self.path.bucket.to_owned(), format!("{}{}", &self.path.key, key), diff --git a/facilitator/tests/integration_tests.rs b/facilitator/tests/integration_tests.rs index a99740a20..5bcaec85f 100644 --- a/facilitator/tests/integration_tests.rs +++ b/facilitator/tests/integration_tests.rs @@ -52,7 +52,7 @@ fn end_to_end() { &PrivateKey::from_base64(DEFAULT_FACILITATOR_ECIES_PRIVATE_KEY).unwrap(), &default_ingestor_private_key(), 10, - 10, + 16, 0.11, 100, 100, @@ -69,7 +69,7 @@ fn end_to_end() { &PrivateKey::from_base64(DEFAULT_FACILITATOR_ECIES_PRIVATE_KEY).unwrap(), &default_ingestor_private_key(), 10, - 10, + 14, 0.11, 100, 100, @@ -260,6 +260,7 @@ fn end_to_end() { &mut *pha_aggregation_transport.transport, ); let pha_sum_part = pha_aggregation_batch_reader.header(&pha_pub_keys).unwrap(); + assert_eq!(pha_sum_part.total_individual_clients, 30); let pha_sum_fields = pha_sum_part.sum().unwrap(); let pha_invalid_packet_reader = pha_aggregation_batch_reader.packet_file_reader(&pha_sum_part); @@ -285,6 +286,7 @@ fn end_to_end() { let facilitator_sum_part = facilitator_aggregation_batch_reader .header(&facilitator_pub_keys) .unwrap(); + assert_eq!(facilitator_sum_part.total_individual_clients, 30); let facilitator_sum_fields = facilitator_sum_part.sum().unwrap(); let facilitator_invalid_packet_reader = @@ -310,13 +312,4 @@ fn end_to_end() { \tfacilitator clients: {}\n\tpha clients: {}", facilitator_sum_part.total_individual_clients, pha_sum_part.total_individual_clients ); - - assert_eq!( - reconstructed.len() as i64, - facilitator_sum_part.total_individual_clients, - "Total individual clients does not match the length of sum\n\ - \ttotal individual clients: {}\n\tlength of sum: {}", - facilitator_sum_part.total_individual_clients, - reconstructed.len() - ); } diff --git a/terraform/main.tf b/terraform/main.tf index e95f899b1..be89a9031 100644 --- a/terraform/main.tf +++ b/terraform/main.tf @@ -67,6 +67,24 @@ variable "is_first" { description = "Whether the data share processors created by this environment are \"first\" or \"PHA servers\"" } +variable "aggregation_period" { + type = string + default = "3h" + description = <