diff --git a/integration-tests/fastauth/src/env/containers.rs b/integration-tests/fastauth/src/env/containers.rs index da9eb41d9..d7d1bf047 100644 --- a/integration-tests/fastauth/src/env/containers.rs +++ b/integration-tests/fastauth/src/env/containers.rs @@ -40,6 +40,7 @@ use testcontainers::{ use tokio::io::AsyncWriteExt; use tracing; +use std::collections::HashMap; use std::fs; use crate::env::{Context, LeaderNodeApi, SignerNodeApi}; @@ -522,6 +523,8 @@ impl SignerNode<'_> { cipher_key: &GenericArray, ) -> anyhow::Result> { tracing::info!("Running signer node container {}...", node_id); + let mut jwt_signature_pk_urls = HashMap::new(); + jwt_signature_pk_urls.insert(ctx.issuer.clone(), ctx.oidc_provider.jwt_pk_url.clone()); let args = mpc_recovery::Cli::StartSign { env: ctx.env.clone(), node_id: node_id as u64, @@ -530,7 +533,7 @@ impl SignerNode<'_> { cipher_key: Some(hex::encode(cipher_key)), gcp_project_id: ctx.gcp_project_id.clone(), gcp_datastore_url: Some(ctx.datastore.address.clone()), - jwt_signature_pk_url: ctx.oidc_provider.jwt_pk_url.clone(), + jwt_signature_pk_urls, logging_options: logging::Options::default(), } .into_str_args(); @@ -636,6 +639,8 @@ impl<'a> LeaderNode<'a> { pub async fn run(ctx: &Context<'a>, sign_nodes: Vec) -> anyhow::Result> { tracing::info!("Running leader node container..."); let account_creator = &ctx.relayer_ctx.creator_account; + let mut jwt_signature_pk_urls = HashMap::new(); + jwt_signature_pk_urls.insert(ctx.issuer.clone(), ctx.oidc_provider.jwt_pk_url.clone()); let args = mpc_recovery::Cli::StartLeader { env: ctx.env.clone(), web_port: Self::CONTAINER_PORT, @@ -667,7 +672,7 @@ impl<'a> LeaderNode<'a> { fast_auth_partners_filepath: None, gcp_project_id: ctx.gcp_project_id.clone(), gcp_datastore_url: Some(ctx.datastore.address.to_string()), - jwt_signature_pk_url: ctx.oidc_provider.jwt_pk_url.to_string(), + jwt_signature_pk_urls, logging_options: logging::Options::default(), } .into_str_args(); diff --git a/integration-tests/fastauth/src/env/local.rs b/integration-tests/fastauth/src/env/local.rs index 2b2ea3010..a0325357a 100644 --- a/integration-tests/fastauth/src/env/local.rs +++ b/integration-tests/fastauth/src/env/local.rs @@ -1,3 +1,5 @@ +use std::collections::HashMap; + use crate::env::{LeaderNodeApi, SignerNodeApi}; use crate::mpc::{self, NodeProcess}; use crate::util; @@ -29,6 +31,8 @@ impl SignerNode { cipher_key: &GenericArray, ) -> anyhow::Result { let web_port = util::pick_unused_port().await?; + let mut jwt_signature_pk_urls = HashMap::new(); + jwt_signature_pk_urls.insert(ctx.issuer.clone(), ctx.oidc_provider.jwt_pk_url.clone()); let cli = mpc_recovery::Cli::StartSign { env: ctx.env.clone(), node_id, @@ -37,7 +41,7 @@ impl SignerNode { cipher_key: Some(hex::encode(cipher_key)), gcp_project_id: ctx.gcp_project_id.clone(), gcp_datastore_url: Some(ctx.datastore.local_address.clone()), - jwt_signature_pk_url: ctx.oidc_provider.jwt_pk_local_url.clone(), + jwt_signature_pk_urls, logging_options: logging::Options::default(), }; @@ -87,6 +91,8 @@ impl LeaderNode { tracing::info!("Running leader node..."); let account_creator = &ctx.relayer_ctx.creator_account; let web_port = util::pick_unused_port().await?; + let mut jwt_signature_pk_urls = HashMap::new(); + jwt_signature_pk_urls.insert(ctx.issuer.clone(), ctx.oidc_provider.jwt_pk_url.clone()); let cli = mpc_recovery::Cli::StartLeader { env: ctx.env.clone(), web_port, @@ -118,7 +124,7 @@ impl LeaderNode { ), gcp_project_id: ctx.gcp_project_id.clone(), gcp_datastore_url: Some(ctx.datastore.local_address.clone()), - jwt_signature_pk_url: ctx.oidc_provider.jwt_pk_local_url.clone(), + jwt_signature_pk_urls, logging_options: logging::Options::default(), }; diff --git a/mpc-recovery/src/leader_node/mod.rs b/mpc-recovery/src/leader_node/mod.rs index 1bfde64aa..9d646fd41 100644 --- a/mpc-recovery/src/leader_node/mod.rs +++ b/mpc-recovery/src/leader_node/mod.rs @@ -34,6 +34,7 @@ use near_primitives::delegate_action::{DelegateAction, NonDelegateAction}; use near_primitives::transaction::{Action, DeleteAccountAction, DeleteKeyAction}; use near_primitives::types::AccountId; use prometheus::{Encoder, TextEncoder}; +use std::collections::HashMap; use std::net::SocketAddr; use std::sync::Arc; use std::time::Instant; @@ -47,7 +48,7 @@ pub struct Config { // TODO: temporary solution pub account_creator_signer: KeyRotatingSigner, pub partners: PartnerList, - pub jwt_signature_pk_url: String, + pub jwt_signature_pk_urls: HashMap, } pub async fn run(config: Config) { @@ -59,7 +60,7 @@ pub async fn run(config: Config) { near_root_account, account_creator_signer, partners, - jwt_signature_pk_url, + jwt_signature_pk_urls, } = config; let _span = tracing::debug_span!("run", env, port); tracing::debug!(?sign_nodes, "running a leader node"); @@ -74,7 +75,7 @@ pub async fn run(config: Config) { near_root_account: near_root_account.parse().unwrap(), account_creator_signer, partners, - jwt_signature_pk_url, + jwt_signature_pk_urls, }); // Get keys from all sign nodes, and broadcast them out as a set. @@ -198,7 +199,7 @@ struct LeaderState { // TODO: temporary solution account_creator_signer: KeyRotatingSigner, partners: PartnerList, - jwt_signature_pk_url: String, + jwt_signature_pk_urls: HashMap, } async fn mpc_public_key( @@ -302,7 +303,7 @@ async fn process_user_credentials( &request.oidc_token, Some(&state.partners.oidc_providers()), &state.reqwest_client, - &state.jwt_signature_pk_url, + &state.jwt_signature_pk_urls, ) .await .map_err(LeaderNodeError::OidcVerificationFailed)?; @@ -334,7 +335,7 @@ async fn process_new_account( &request.oidc_token, Some(&state.partners.oidc_providers()), &state.reqwest_client, - &state.jwt_signature_pk_url, + &state.jwt_signature_pk_urls, ) .await .map_err(LeaderNodeError::OidcVerificationFailed)?; @@ -477,7 +478,7 @@ async fn process_sign( &request.oidc_token, Some(&state.partners.oidc_providers()), &state.reqwest_client, - &state.jwt_signature_pk_url, + &state.jwt_signature_pk_urls, ) .await .map_err(LeaderNodeError::OidcVerificationFailed)?; diff --git a/mpc-recovery/src/lib.rs b/mpc-recovery/src/lib.rs index 87ec5992a..5bd9246b4 100644 --- a/mpc-recovery/src/lib.rs +++ b/mpc-recovery/src/lib.rs @@ -1,6 +1,7 @@ // TODO: FIXME: Remove this once we have a better way to handle these large errors #![allow(clippy::result_large_err)] +use std::collections::HashMap; use std::path::PathBuf; use aes_gcm::aead::consts::U32; @@ -113,9 +114,9 @@ pub enum Cli { /// GCP datastore URL #[arg(long, env("MPC_RECOVERY_GCP_DATASTORE_URL"))] gcp_datastore_url: Option, - /// URL to the public key used to sign JWT tokens - #[arg(long, env("MPC_RECOVERY_JWT_SIGNATURE_PK_URL"))] - jwt_signature_pk_url: String, + /// URLs of the public keys used by all issuers + #[arg(long, value_parser = parse_json_str::>, env("MPC_RECOVERY_JWT_SIGNATURE_PK_URLS"))] + jwt_signature_pk_urls: HashMap, /// Enables export of span data using opentelemetry protocol. #[clap(flatten)] logging_options: logging::Options, @@ -142,9 +143,9 @@ pub enum Cli { /// GCP datastore URL #[arg(long, env("MPC_RECOVERY_GCP_DATASTORE_URL"))] gcp_datastore_url: Option, - /// URL to the public key used to sign JWT tokens - #[arg(long, env("MPC_RECOVERY_JWT_SIGNATURE_PK_URL"))] - jwt_signature_pk_url: String, + /// URLs of the public keys used by all issuers + #[arg(long, value_parser = parse_json_str::>, env("MPC_RECOVERY_JWT_SIGNATURE_PK_URLS"))] + jwt_signature_pk_urls: HashMap, /// Enables export of span data using opentelemetry protocol. #[clap(flatten)] logging_options: logging::Options, @@ -203,7 +204,7 @@ pub async fn run(cmd: Cli) -> anyhow::Result<()> { fast_auth_partners_filepath: partners_filepath, gcp_project_id, gcp_datastore_url, - jwt_signature_pk_url, + jwt_signature_pk_urls, logging_options, } => { let _subscriber_guard = logging::subscribe_global( @@ -231,7 +232,7 @@ pub async fn run(cmd: Cli) -> anyhow::Result<()> { near_root_account, account_creator_signer, partners, - jwt_signature_pk_url, + jwt_signature_pk_urls, }; run_leader_node(config).await; @@ -244,7 +245,7 @@ pub async fn run(cmd: Cli) -> anyhow::Result<()> { web_port, gcp_project_id, gcp_datastore_url, - jwt_signature_pk_url, + jwt_signature_pk_urls, logging_options, } => { let _subscriber_guard = logging::subscribe_global( @@ -272,7 +273,7 @@ pub async fn run(cmd: Cli) -> anyhow::Result<()> { node_key: sk_share, cipher, port: web_port, - jwt_signature_pk_url, + jwt_signature_pk_urls, }; run_sign_node(config).await; } @@ -428,7 +429,7 @@ impl Cli { fast_auth_partners_filepath, gcp_project_id, gcp_datastore_url, - jwt_signature_pk_url, + jwt_signature_pk_urls, logging_options, } => { let mut buf = vec![ @@ -445,8 +446,6 @@ impl Cli { account_creator_id.to_string(), "--gcp-project-id".to_string(), gcp_project_id, - "--jwt-signature-pk-url".to_string(), - jwt_signature_pk_url, ]; if let Some(partners) = fast_auth_partners { @@ -465,6 +464,11 @@ impl Cli { buf.push("--sign-nodes".to_string()); buf.push(sign_node); } + + let jwt_signature_pk_urls = serde_json::to_string(&jwt_signature_pk_urls).unwrap(); + buf.push("--jwt-signature-pk-urls".to_string()); + buf.push(jwt_signature_pk_urls); + let account_creator_sk = serde_json::to_string(&account_creator_sk).unwrap(); buf.push("--account-creator-sk".to_string()); buf.push(account_creator_sk); @@ -480,7 +484,7 @@ impl Cli { sk_share, gcp_project_id, gcp_datastore_url, - jwt_signature_pk_url, + jwt_signature_pk_urls, logging_options, } => { let mut buf = vec![ @@ -493,8 +497,6 @@ impl Cli { web_port.to_string(), "--gcp-project-id".to_string(), gcp_project_id, - "--jwt-signature-pk-url".to_string(), - jwt_signature_pk_url, ]; if let Some(key) = cipher_key { buf.push("--cipher-key".to_string()); @@ -508,6 +510,11 @@ impl Cli { buf.push("--gcp-datastore-url".to_string()); buf.push(gcp_datastore_url); } + + let jwt_signature_pk_urls = serde_json::to_string(&jwt_signature_pk_urls).unwrap(); + buf.push("--jwt-signature-pk-urls".to_string()); + buf.push(jwt_signature_pk_urls); + buf.extend(logging_options.into_str_args()); buf diff --git a/mpc-recovery/src/oauth.rs b/mpc-recovery/src/oauth.rs index a22e57a1c..f4f84aee9 100644 --- a/mpc-recovery/src/oauth.rs +++ b/mpc-recovery/src/oauth.rs @@ -1,5 +1,8 @@ +use anyhow::{Context, Result}; +use base64::{engine::general_purpose::STANDARD as BASE64, Engine as _}; use jsonwebtoken::{Algorithm, DecodingKey}; use serde::{Deserialize, Serialize}; +use serde_json::Value; use std::collections::HashMap; use crate::firewall::allowed::OidcProviderList; @@ -13,15 +16,21 @@ pub async fn verify_oidc_token( token: &OidcToken, oidc_providers: Option<&OidcProviderList>, client: &reqwest::Client, - jwt_signature_pk_url: &str, + jwt_signature_pk_urls: &HashMap, ) -> anyhow::Result { - let public_keys = get_pagoda_firebase_public_keys(client, jwt_signature_pk_url) + let (_, claims, _) = token.decode_unverified()?; + let issuer = &claims.iss; + + let jwks_url = jwt_signature_pk_urls + .get(issuer) + .ok_or_else(|| anyhow::anyhow!("No JWKS URL found for issuer: {}", issuer))?; + + let public_keys = get_public_keys(client, jwks_url) .await - .map_err(|e| anyhow::anyhow!("failed to get Firebase public key: {e}"))?; - tracing::info!("verify_oidc_token firebase public keys: {public_keys:?}"); + .map_err(|e| anyhow::anyhow!("failed to get public keys: {e}"))?; + tracing::info!("verify_oidc_token public keys: {public_keys:?}"); - let mut last_occured_error = - anyhow::anyhow!("Unexpected error. Firebase public keys not found"); + let mut last_occured_error = anyhow::anyhow!("Unexpected error. Public keys not found"); for public_key in public_keys { match validate_jwt(token, public_key.as_bytes(), oidc_providers) { Ok(claims) => { @@ -99,13 +108,49 @@ impl IdTokenClaims { } } -pub async fn get_pagoda_firebase_public_keys( - client: &reqwest::Client, - jwt_signature_pk_url: &str, -) -> anyhow::Result> { - let response = client.get(jwt_signature_pk_url).send().await?; - let json: HashMap = response.json().await?; - Ok(json.into_values().collect()) +pub async fn get_public_keys(client: &reqwest::Client, jwks_url: &str) -> Result> { + let response = client + .get(jwks_url) + .send() + .await + .context("Failed to send request")?; + + let json: Value = response.json().await.context("Failed to parse JSON")?; + + match json { + Value::Object(obj) if obj.contains_key("keys") => parse_jwks_format(&obj), + Value::Object(obj) => parse_firebase_format(&obj), + _ => { + tracing::warn!("Unexpected response format from {}", jwks_url); + Ok(vec![]) + } + } +} + +fn parse_jwks_format(obj: &serde_json::Map) -> Result> { + obj["keys"] + .as_array() + .context("'keys' is not an array")? + .iter() + .filter_map(|key| match (key["n"].as_str(), key["e"].as_str()) { + (Some(n), Some(e)) => Some(format_rsa_key(n, e)), + _ => None, + }) + .collect::>>() +} + +fn parse_firebase_format(obj: &serde_json::Map) -> Result> { + Ok(obj + .values() + .filter_map(|value| value.as_str().map(String::from)) + .collect()) +} + +fn format_rsa_key(n: &str, e: &str) -> Result { + Ok(format!( + "-----BEGIN PUBLIC KEY-----\n{}\n-----END PUBLIC KEY-----", + BASE64.encode(format!("{}:{}", n, e)) + )) } #[cfg(test)] @@ -124,7 +169,8 @@ mod tests { let url = "https://www.googleapis.com/robot/v1/metadata/x509/securetoken@system.gserviceaccount.com"; let client = reqwest::Client::new(); - let pk = get_pagoda_firebase_public_keys(&client, url).await.unwrap(); + let pk = get_public_keys(&client, url).await.unwrap(); + assert!(!pk.is_empty()); } diff --git a/mpc-recovery/src/sign_node/mod.rs b/mpc-recovery/src/sign_node/mod.rs index e43dd7abb..22419c3ec 100644 --- a/mpc-recovery/src/sign_node/mod.rs +++ b/mpc-recovery/src/sign_node/mod.rs @@ -24,6 +24,7 @@ use multi_party_eddsa::protocols::{self, ExpandedKeyPair}; use near_primitives::hash::hash; use near_primitives::signable_message::{SignableMessage, SignableMessageType}; +use std::collections::HashMap; use std::net::SocketAddr; use std::sync::Arc; @@ -39,7 +40,7 @@ pub struct Config { pub node_key: ExpandedKeyPair, pub cipher: Aes256Gcm, pub port: u16, - pub jwt_signature_pk_url: String, + pub jwt_signature_pk_urls: HashMap, } pub async fn run(config: Config) { @@ -50,7 +51,7 @@ pub async fn run(config: Config) { node_key, cipher, port, - jwt_signature_pk_url, + jwt_signature_pk_urls, } = config; let our_index = usize::try_from(our_index).expect("This index is way to big"); @@ -66,7 +67,7 @@ pub async fn run(config: Config) { cipher, signing_state: SigningState::new(), node_info: NodeInfo::new(our_index, pk_set.map(|set| set.public_keys)), - jwt_signature_pk_url, + jwt_signature_pk_urls, }); let app = Router::new() @@ -101,7 +102,7 @@ struct SignNodeState { cipher: Aes256Gcm, signing_state: SigningState, node_info: NodeInfo, - jwt_signature_pk_url: String, + jwt_signature_pk_urls: HashMap, } async fn get_or_generate_user_creds( @@ -213,7 +214,7 @@ async fn process_commit( &request.oidc_token, None, &state.reqwest_client, - &state.jwt_signature_pk_url, + &state.jwt_signature_pk_urls, ) .await .map_err(SignNodeError::OidcVerificationFailed)?; @@ -369,7 +370,7 @@ async fn process_public_key( &request.oidc_token, None, &state.reqwest_client, - &state.jwt_signature_pk_url, + &state.jwt_signature_pk_urls, ) .await .map_err(SignNodeError::OidcVerificationFailed)?; diff --git a/mpc-recovery/src/sign_node/oidc.rs b/mpc-recovery/src/sign_node/oidc.rs index e7b2ca2ef..df016aac3 100644 --- a/mpc-recovery/src/sign_node/oidc.rs +++ b/mpc-recovery/src/sign_node/oidc.rs @@ -98,6 +98,21 @@ impl OidcToken { Ok((header, claims, signature.into())) } + + // NOTE: code taken directly from our implementation of token.decode but without the verification step + pub fn decode_unverified(&self) -> anyhow::Result<(jwt::Header, IdTokenClaims, String)> { + let mut parts = self.as_ref().rsplitn(2, '.'); + let (Some(signature), Some(message)) = (parts.next(), parts.next()) else { + anyhow::bail!("could not split into signature and message for OIDC token"); + }; + let mut parts = message.rsplitn(2, '.'); + let (Some(payload), Some(header)) = (parts.next(), parts.next()) else { + anyhow::bail!("could not split into payload and header for OIDC token"); + }; + let header: jwt::Header = serde_json::from_slice(&b64_decode(header)?)?; + let claims: IdTokenClaims = serde_json::from_slice(&b64_decode(payload)?)?; + Ok((header, claims, signature.to_string())) + } } fn b64_decode>(input: T) -> anyhow::Result> {