diff --git a/aggregator/src/aggregator.rs b/aggregator/src/aggregator.rs index 0b5db2a66..0ea6775ac 100644 --- a/aggregator/src/aggregator.rs +++ b/aggregator/src/aggregator.rs @@ -355,9 +355,9 @@ impl Aggregator { ) -> Result { let task_aggregator = match self.task_aggregator_for(task_id).await? { Some(task_aggregator) => { - if !auth_token - .map(|t| task_aggregator.task.check_aggregator_auth_token(&t)) - .unwrap_or(false) + if !task_aggregator + .task + .check_aggregator_auth_token(auth_token.as_ref()) { return Err(Error::UnauthorizedRequest(*task_id)); } @@ -426,9 +426,9 @@ impl Aggregator { auth_token.as_ref(), ) .await?; - } else if !auth_token - .map(|t| task_aggregator.task.check_aggregator_auth_token(&t)) - .unwrap_or(false) + } else if !task_aggregator + .task + .check_aggregator_auth_token(auth_token.as_ref()) { return Err(Error::UnauthorizedRequest(*task_id)); } @@ -465,9 +465,9 @@ impl Aggregator { if task_aggregator.task.role() != &Role::Leader { return Err(Error::UnrecognizedTask(*task_id)); } - if !auth_token - .map(|t| task_aggregator.task.check_collector_auth_token(&t)) - .unwrap_or(false) + if !task_aggregator + .task + .check_collector_auth_token(auth_token.as_ref()) { return Err(Error::UnauthorizedRequest(*task_id)); } @@ -494,9 +494,9 @@ impl Aggregator { if task_aggregator.task.role() != &Role::Leader { return Err(Error::UnrecognizedTask(*task_id)); } - if !auth_token - .map(|t| task_aggregator.task.check_collector_auth_token(&t)) - .unwrap_or(false) + if !task_aggregator + .task + .check_collector_auth_token(auth_token.as_ref()) { return Err(Error::UnauthorizedRequest(*task_id)); } @@ -520,9 +520,9 @@ impl Aggregator { if task_aggregator.task.role() != &Role::Leader { return Err(Error::UnrecognizedTask(*task_id)); } - if !auth_token - .map(|t| task_aggregator.task.check_collector_auth_token(&t)) - .unwrap_or(false) + if !task_aggregator + .task + .check_collector_auth_token(auth_token.as_ref()) { return Err(Error::UnauthorizedRequest(*task_id)); } @@ -566,9 +566,9 @@ impl Aggregator { peer_aggregator.collector_hpke_config() } else { - if !auth_token - .map(|t| task_aggregator.task.check_aggregator_auth_token(&t)) - .unwrap_or(false) + if !task_aggregator + .task + .check_aggregator_auth_token(auth_token.as_ref()) { return Err(Error::UnauthorizedRequest(*task_id)); } diff --git a/aggregator_core/src/task.rs b/aggregator_core/src/task.rs index 54a3a85e0..3d036c1ca 100644 --- a/aggregator_core/src/task.rs +++ b/aggregator_core/src/task.rs @@ -4,7 +4,7 @@ use crate::SecretBytes; use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine}; use derivative::Derivative; use janus_core::{ - auth_tokens::AuthenticationToken, + auth_tokens::{AuthenticationToken, AuthenticationTokenHash}, hpke::{generate_hpke_config_and_private_key, HpkeKeypair}, time::TimeExt, url_ensure_trailing_slash, @@ -401,22 +401,35 @@ impl Task { } } - /// Checks if the given aggregator authentication token is valid (i.e. matches with an + /// Checks if the given aggregator authentication token is valid (i.e. matches with the /// authentication token recognized by this task). - pub fn check_aggregator_auth_token(&self, auth_token: &AuthenticationToken) -> bool { - match self.aggregator_auth_token { - Some(ref t) => t == auth_token, - None => false, - } + pub fn check_aggregator_auth_token( + &self, + incoming_auth_token: Option<&AuthenticationToken>, + ) -> bool { + // TODO(#1509): leader should hold only an AuthenticationToken and refuse to use it for + // incoming token validation. Helper should hold only an AuthenticationTokenHash, making the + // AuthenticationTokenHash::from call here unnecessary. + self.aggregator_auth_token() + .map(AuthenticationTokenHash::from) + .zip(incoming_auth_token) + .map(|(own_token_hash, incoming_token)| own_token_hash.validate(incoming_token)) + .unwrap_or(false) } - /// Checks if the given collector authentication token is valid (i.e. matches with an + /// Checks if the given collector authentication token is valid (i.e. matches with the /// authentication token recognized by this task). - pub fn check_collector_auth_token(&self, auth_token: &AuthenticationToken) -> bool { - match self.collector_auth_token { - Some(ref t) => t == auth_token, - None => false, - } + pub fn check_collector_auth_token( + &self, + incoming_auth_token: Option<&AuthenticationToken>, + ) -> bool { + // TODO(#1509): Leader should hold only an AuthenticaitonTokenHash, making the + // AuthenticationTokenHash::from call here unnecessary. + self.collector_auth_token() + .map(AuthenticationTokenHash::from) + .zip(incoming_auth_token) + .map(|(own_token_hash, incoming_token)| own_token_hash.validate(incoming_token)) + .unwrap_or(false) } /// Returns the [`VerifyKey`] used by this aggregator to prepare report shares with other diff --git a/core/src/auth_tokens.rs b/core/src/auth_tokens.rs index 1bf3ce6a1..2c7c54043 100644 --- a/core/src/auth_tokens.rs +++ b/core/src/auth_tokens.rs @@ -2,8 +2,11 @@ use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine}; use derivative::Derivative; use http::{header::AUTHORIZATION, HeaderValue}; use rand::{distributions::Standard, prelude::Distribution}; -use ring::constant_time; -use serde::{de::Error, Deserialize, Deserializer, Serialize}; +use ring::{ + constant_time, + digest::{digest, SHA256, SHA256_OUTPUT_LEN}, +}; +use serde::{de::Error, Deserialize, Deserializer, Serialize, Serializer}; use std::str; /// HTTP header where auth tokens are provided in messages between participants. @@ -291,9 +294,115 @@ impl Distribution for Standard { } } +/// The hash of an authentication token, which may be used to validate tokens in incoming requests +/// but not to authenticate outgoing requests. +#[derive(Clone, Derivative, Deserialize, Serialize, Eq)] +#[derivative(Debug)] +#[serde(tag = "type", content = "hash")] +#[non_exhaustive] +pub enum AuthenticationTokenHash { + /// A bearer token, presented as the value of the "Authorization" HTTP header as specified in + /// [RFC 6750 section 2.1][1]. + /// + /// The token is not necessarily an OAuth token. + /// + /// [1]: https://datatracker.ietf.org/doc/html/rfc6750#section-2.1 + Bearer( + #[derivative(Debug = "ignore")] + #[serde( + serialize_with = "AuthenticationTokenHash::serialize_contents", + deserialize_with = "AuthenticationTokenHash::deserialize_contents" + )] + [u8; SHA256_OUTPUT_LEN], + ), + + /// Token presented as the value of the "DAP-Auth-Token" HTTP header. Conforms to + /// [draft-dcook-ppm-dap-interop-test-design-03][1], sections [4.3.3][2] and [4.4.2][3], and + /// [draft-ietf-dap-ppm-01 section 3.2][4]. + /// + /// [1]: https://datatracker.ietf.org/doc/html/draft-dcook-ppm-dap-interop-test-design-03 + /// [2]: https://datatracker.ietf.org/doc/html/draft-dcook-ppm-dap-interop-test-design-03#section-4.3.3 + /// [3]: https://datatracker.ietf.org/doc/html/draft-dcook-ppm-dap-interop-test-design-03#section-4.4.2 + /// [4]: https://datatracker.ietf.org/doc/html/draft-ietf-ppm-dap-01#name-https-sender-authentication + DapAuth( + #[derivative(Debug = "ignore")] + #[serde( + serialize_with = "AuthenticationTokenHash::serialize_contents", + deserialize_with = "AuthenticationTokenHash::deserialize_contents" + )] + [u8; SHA256_OUTPUT_LEN], + ), +} + +impl AuthenticationTokenHash { + /// Returns true if the incoming unhashed token matches this token hash, false otherwise. + pub fn validate(&self, incoming_token: &AuthenticationToken) -> bool { + &Self::from(incoming_token) == self + } + + fn serialize_contents( + value: &[u8; SHA256_OUTPUT_LEN], + serializer: S, + ) -> Result { + serializer.serialize_str(&URL_SAFE_NO_PAD.encode(value)) + } + + fn deserialize_contents<'de, D>(deserializer: D) -> Result<[u8; SHA256_OUTPUT_LEN], D::Error> + where + D: Deserializer<'de>, + { + let b64_digest: String = Deserialize::deserialize(deserializer)?; + let decoded = URL_SAFE_NO_PAD + .decode(b64_digest) + .map_err(D::Error::custom)?; + + decoded + .try_into() + .map_err(|_| D::Error::custom("digest has wrong length")) + } +} + +impl From<&AuthenticationToken> for AuthenticationTokenHash { + fn from(value: &AuthenticationToken) -> Self { + // unwrap safety: try_into is converting from &[u8] to [u8; SHA256_OUTPUT_LEN]. SHA256 + // output will always be that length, so this conversion should never fail. + let digest = digest(&SHA256, value.as_ref()).as_ref().try_into().unwrap(); + + match value { + AuthenticationToken::Bearer(_) => Self::Bearer(digest), + AuthenticationToken::DapAuth(_) => Self::DapAuth(digest), + } + } +} + +impl PartialEq for AuthenticationTokenHash { + fn eq(&self, other: &Self) -> bool { + let (self_digest, other_digest) = match (self, other) { + (Self::Bearer(self_digest), Self::Bearer(other_digest)) => (self_digest, other_digest), + (Self::DapAuth(self_digest), Self::DapAuth(other_digest)) => { + (self_digest, other_digest) + } + _ => return false, + }; + + // We attempt constant-time comparisons of the token data to mitigate timing attacks. + constant_time::verify_slices_are_equal(self_digest.as_ref(), other_digest.as_ref()).is_ok() + } +} + +impl AsRef<[u8]> for AuthenticationTokenHash { + fn as_ref(&self) -> &[u8] { + match self { + Self::Bearer(inner) => inner.as_slice(), + Self::DapAuth(inner) => inner.as_slice(), + } + } +} + #[cfg(test)] mod tests { - use crate::auth_tokens::AuthenticationToken; + use crate::auth_tokens::{AuthenticationToken, AuthenticationTokenHash}; + use rand::random; #[test] fn valid_dap_auth_token() { @@ -330,4 +439,46 @@ mod tests { serde_yaml::from_str::("{type: \"Bearer\", token: \"AAAA==AAA\"}") .unwrap_err(); } + + #[rstest::rstest] + #[case::bearer(r#"{ type: "Bearer", hash: "MJOoBO_ysLEuG_lv2C37eEOf1Ngetsr-Ers0ZYj4vdQ" }"#)] + #[case::dap_auth(r#"{ type: "DapAuth", hash: "MJOoBO_ysLEuG_lv2C37eEOf1Ngetsr-Ers0ZYj4vdQ" }"#)] + #[test] + fn serde_aggregator_token_hash_valid(#[case] yaml: &str) { + serde_yaml::from_str::(yaml).unwrap(); + } + + #[rstest::rstest] + #[case::bearer_token_invalid_encoding(r#"{ type: "Bearer", hash: "+" }"#)] + #[case::bearer_token_wrong_length( + r#"{ type: "Bearer", hash: "MJOoBO_ysLEuG_lv2C37eEOf1Ngetsr-Ers0ZYj4" }"# + )] + #[case::dap_auth_token_invalid_encoding(r#"{ type: "DapAuth", hash: "+" }"#)] + #[case::dap_auth_token_wrong_length( + r#"{ type: "DapAuth", hash: "MJOoBO_ysLEuG_lv2C37eEOf1Ngetsr-Ers0ZYj4" }"# + )] + #[test] + fn serde_aggregator_token_hash_invalid(#[case] yaml: &str) { + serde_yaml::from_str::(yaml).unwrap_err(); + } + + #[test] + fn validate_token() { + let dap_auth_token_1 = AuthenticationToken::DapAuth(random()); + let dap_auth_token_2 = AuthenticationToken::DapAuth(random()); + let bearer_token_1 = AuthenticationToken::Bearer(random()); + let bearer_token_2 = AuthenticationToken::Bearer(random()); + + assert_eq!(dap_auth_token_1, dap_auth_token_1); + assert_ne!(dap_auth_token_1, dap_auth_token_2); + assert_eq!(bearer_token_1, bearer_token_1); + assert_ne!(bearer_token_1, bearer_token_2); + assert_ne!(dap_auth_token_1, bearer_token_1); + + assert!(AuthenticationTokenHash::from(&dap_auth_token_1).validate(&dap_auth_token_1)); + assert!(!AuthenticationTokenHash::from(&dap_auth_token_1).validate(&dap_auth_token_2)); + assert!(AuthenticationTokenHash::from(&bearer_token_1).validate(&bearer_token_1)); + assert!(!AuthenticationTokenHash::from(&bearer_token_1).validate(&bearer_token_2)); + assert!(!AuthenticationTokenHash::from(&dap_auth_token_1).validate(&bearer_token_1)); + } }