diff --git a/Cargo.lock b/Cargo.lock index 2ad5854..e9006f4 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2723,6 +2723,7 @@ dependencies = [ "dropshot-authorization-header", "dropshot-verified-body", "google-cloudkms1", + "hex", "http", "hyper", "hyper-tls", @@ -2731,6 +2732,7 @@ dependencies = [ "oauth2", "octorust", "partial-struct", + "rand", "rand_core", "regex", "reqwest", @@ -2789,6 +2791,7 @@ dependencies = [ "chrono", "diesel", "diesel_migrations", + "http", "mockall", "partial-struct", "schemars", diff --git a/Cargo.toml b/Cargo.toml index a68ac37..2a3802f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -48,6 +48,7 @@ octorust = "0.7.0-rc.1" partial-struct = { git = "https://github.com/oxidecomputer/partial-struct" } progenitor = { git = "https://github.com/oxidecomputer/progenitor" } progenitor-client = { git = "https://github.com/oxidecomputer/progenitor" } +rand = "0.8.5" rand_core = "0.6" regex = "1.7.1" reqwest = { version = "0.11", features = ["json", "stream"] } diff --git a/rfd-api/Cargo.toml b/rfd-api/Cargo.toml index 768c568..d9e6ea8 100644 --- a/rfd-api/Cargo.toml +++ b/rfd-api/Cargo.toml @@ -16,6 +16,7 @@ dropshot = { workspace = true } dropshot-authorization-header = { path = "../dropshot-authorization-header" } dropshot-verified-body = { workspace = true, features = ["github"] } google-cloudkms1 = { workspace = true } +hex = { workspace = true } http = { workspace = true } hyper = { workspace = true } hyper-tls = { workspace = true } @@ -23,6 +24,7 @@ jsonwebtoken = { workspace = true } oauth2 = { workspace = true } octorust = { workspace = true } partial-struct = { workspace = true } +rand = { workspace = true, features = ["std"] } rand_core = { workspace = true, features = ["std"] } regex = { workspace = true } reqwest = { workspace = true } diff --git a/rfd-api/src/authn/jwt.rs b/rfd-api/src/authn/jwt.rs index fd2595b..32d4d97 100644 --- a/rfd-api/src/authn/jwt.rs +++ b/rfd-api/src/authn/jwt.rs @@ -19,7 +19,7 @@ use thiserror::Error; use tracing::instrument; use uuid::Uuid; -use crate::{config::JwtKey, context::ApiContext, ApiPermissions}; +use crate::{config::AsymmetricKey, context::ApiContext, ApiPermissions}; #[derive(Debug, Error)] pub enum JwtError { @@ -168,7 +168,7 @@ pub enum CloudKmsError { // Signer that relies on a private key stored in GCP, and a locally store JWK. This signer never // has direct access to the private key -pub struct CloudKmSigner { +pub struct CloudKmsSigner { client: CloudKMS>, key_name: String, header: Header, @@ -187,7 +187,7 @@ pub struct CloudKmsSignatureResponse { } #[async_trait] -impl JwtSigner for CloudKmSigner { +impl JwtSigner for CloudKmsSigner { type Claims = Claims; #[instrument(skip(self, claims), err(Debug))] @@ -302,10 +302,10 @@ fn pem_to_jwk(id: &str, pem: &str) -> Result { #[instrument(skip(key), err(Debug))] pub async fn key_to_signer( - key: &JwtKey, + key: &AsymmetricKey, ) -> Result>, SignerError> { Ok(match key { - JwtKey::Local { + AsymmetricKey::Local { kid, private, public, @@ -322,7 +322,7 @@ pub async fn key_to_signer( jwk, }) } - JwtKey::Ckms { + AsymmetricKey::Ckms { kid, version, key, @@ -399,7 +399,7 @@ pub async fn key_to_signer( tracing::trace!(?header, ?jwk, "Generated Cloud KMS signer"); - Box::new(CloudKmSigner { + Box::new(CloudKmsSigner { client: gcp_kms, key_name, header, diff --git a/rfd-api/src/authn/key.rs b/rfd-api/src/authn/key.rs index 6505a60..216b439 100644 --- a/rfd-api/src/authn/key.rs +++ b/rfd-api/src/authn/key.rs @@ -1,158 +1,219 @@ -use argon2::{ - password_hash::{ - rand_core::{OsRng, RngCore}, - PasswordHash, PasswordHasher, PasswordVerifier, SaltString, - }, - Argon2, ParamsBuilder, Version, -}; +use argon2::password_hash::rand_core::{OsRng, RngCore}; +use async_trait::async_trait; use base64::{prelude::BASE64_URL_SAFE_NO_PAD, Engine}; -use std::time::Instant; -use uuid::Uuid; - -// Parameters and algorithm are based on recommendations in: -// https://soatok.blog/2022/12/29/what-we-do-in-the-etc-shadow-cryptography-with-passwords/ -// https://cheatsheetseries.owasp.org/cheatsheets/Password_Storage_Cheat_Sheet.html -fn argon2() -> Argon2<'static> { - // Given that our parameters are static, we should never fail to build an instance - let mut params = ParamsBuilder::new(); - params.m_cost(24 * 1024).unwrap(); - params.t_cost(6).unwrap(); - params.p_cost(1).unwrap(); - - Argon2::new( - argon2::Algorithm::Argon2id, - Version::default(), - params.params().unwrap(), - ) -} +use google_cloudkms1::{hyper_rustls, CloudKMS}; +use rsa::{pkcs1::DecodeRsaPublicKey, Pkcs1v15Encrypt, PublicKey, RsaPublicKey}; +use thiserror::Error; +use tracing::instrument; + +use crate::config::AsymmetricKey; + +use super::jwt::CloudKmsError; -pub struct NewApiKey { +pub struct RawApiKey { clear: String, - hash: String, } -impl NewApiKey { - // Generate a new API key along with its hashed form - pub fn generate(id: &Uuid) -> Self { +#[derive(Debug, Error)] +pub enum KeyEncryptionFailure { + #[error(transparent)] + EncryptionFailure(#[from] EncryptorError), +} + +impl RawApiKey { + // Generate a new API key + pub fn generate() -> Self { // Generate random data to extend the token id with let mut token_raw = [0; N]; OsRng.fill_bytes(&mut token_raw); - // Append the random data to the token's id - let mut to_encode = id.as_bytes().to_vec(); - to_encode.extend(token_raw); + let clear = BASE64_URL_SAFE_NO_PAD.encode(token_raw); - let clear = BASE64_URL_SAFE_NO_PAD.encode(to_encode); - let salt = SaltString::generate(&mut OsRng); - - // Given that our Argon2 parameters are static, and our passwords are always the same size, - // we should not be able to actually hit an error case here - let hash = argon2() - .hash_password(clear.as_bytes(), &salt) - .unwrap() - .to_string(); + Self { clear } + } - Self { clear, hash } + // To get the token out of an API key it must be consumed so that it can not be used again + pub fn consume(self) -> String { + self.clear } - // To get the token and hash out of an API key it must be consumed so that it can not be used - // again - pub fn consume(self) -> (String, String) { - (self.clear, self.hash) + pub async fn encrypt( + &self, + encryptor: &dyn KeyEncryptor, + ) -> Result { + let encrypted = encryptor.encrypt(&self.clear).await?; + Ok(EncryptedApiKey { encrypted }) } } -pub struct ApiKey { - id: Uuid, - token: String, +pub struct EncryptedApiKey { + pub encrypted: String, } -impl ApiKey { - pub fn id(&self) -> &Uuid { - &self.id +impl From<&str> for RawApiKey { + fn from(value: &str) -> Self { + RawApiKey { + clear: value.to_string(), + } } +} - pub fn verify(&self, hash: &str) -> bool { - // If we somehow stored an invalid hash, immediately fail - let start = Instant::now(); - - let result = PasswordHash::new(hash) - .ok() - .and_then(|parsed_hash| { - argon2() - .verify_password(self.token.as_bytes(), &parsed_hash) - .ok() - }) - .is_some(); +// Represents a service for encrypting tokens +#[async_trait] +pub trait KeyEncryptor: Send + Sync { + async fn encrypt(&self, value: &str) -> Result; +} - let end = Instant::now(); - let duration = end - start; +#[derive(Debug, Error)] +pub enum EncryptorError { + #[error(transparent)] + CloudKms(#[from] CloudKmsError), + #[error(transparent)] + Encryption(#[from] rsa::errors::Error), + #[error(transparent)] + PemDecode(#[from] rsa::pkcs1::Error), + #[error("Failed to construct credentials for remote key storage")] + RemoteKeyAuthMissing, + #[error("Key input does not match requested output")] + SigningConfigurationMismatch, +} - tracing::debug!(?duration, "Api key verification measurement"); +// A signer that stores a local in memory key for signing new JWTs +pub struct LocalKey { + public_key: RsaPublicKey, +} - result +#[async_trait] +impl KeyEncryptor for LocalKey { + #[instrument(skip(self, value), err(Debug))] + async fn encrypt(&self, value: &str) -> Result { + let mut rng = rand::thread_rng(); + Ok(hex::encode(self.public_key.encrypt( + &mut rng, + Pkcs1v15Encrypt, + value.as_bytes(), + )?)) } } -#[derive(Debug)] -pub struct FailedToParseToken {} - -impl TryFrom<&str> for ApiKey { - type Error = FailedToParseToken; - fn try_from(token: &str) -> Result { - BASE64_URL_SAFE_NO_PAD - .decode(&token) - .ok() - .map(|decoded| (token, decoded)) - .and_then(|(token, decoded)| { - tracing::trace!("Decoded token {:?} {:?}", token, decoded); - - if let Some(id) = Uuid::from_slice(&decoded[0..16]).ok() { - Some(ApiKey { - id, - token: token.to_string(), - }) - } else { - tracing::info!("Failed to decode token"); - None +#[instrument(skip(key), err(Debug))] +pub async fn key_to_encryptor( + key: &AsymmetricKey, +) -> Result, EncryptorError> { + let pem = match key { + AsymmetricKey::Local { public, .. } => public.to_string(), + AsymmetricKey::Ckms { + version, + key, + keyring, + location, + project, + .. + } => { + let opts = yup_oauth2::ApplicationDefaultCredentialsFlowOpts::default(); + + tracing::trace!(?opts, "Request GCP credentials"); + + let gcp_credentials = + yup_oauth2::ApplicationDefaultCredentialsAuthenticator::builder(opts).await; + + tracing::trace!("Retrieved GCP credentials"); + + let gcp_auth = match gcp_credentials { + yup_oauth2::authenticator::ApplicationDefaultCredentialsTypes::ServiceAccount( + auth, + ) => { + tracing::debug!("Create GCP service account based credentials"); + + auth.build().await.map_err(|err| { + tracing::error!( + ?err, + "Failed to construct Cloud KMS credentials from service account" + ); + EncryptorError::RemoteKeyAuthMissing + })? } - }) - .ok_or(FailedToParseToken {}) - } + yup_oauth2::authenticator::ApplicationDefaultCredentialsTypes::InstanceMetadata( + auth, + ) => { + tracing::debug!("Create GCP instance based credentials"); + + auth.build().await.map_err(|err| { + tracing::error!( + ?err, + "Failed to construct Cloud KMS credentials from instance metadata" + ); + EncryptorError::RemoteKeyAuthMissing + })? + } + }; + + let gcp_kms = CloudKMS::new( + hyper::Client::builder().build( + hyper_rustls::HttpsConnectorBuilder::new() + .with_native_roots() + .https_only() + .enable_http2() + .build(), + ), + gcp_auth, + ); + + let key_name = format!( + "projects/{}/locations/{}/keyRings/{}/cryptoKeys/{}/cryptoKeyVersions/{}", + project, location, keyring, key, version + ); + let public_key = gcp_kms + .projects() + .locations_key_rings_crypto_keys_crypto_key_versions_get_public_key(&key_name) + .doit() + .await + .map_err(|err| CloudKmsError::from(err))? + .1; + + let pem = public_key.pem.ok_or(CloudKmsError::MissingPem)?; + + pem + } + }; + + Ok(Box::new(LocalKey { + public_key: RsaPublicKey::from_pkcs1_pem(&pem)?, + })) } -#[cfg(test)] -mod tests { - use super::{ApiKey, NewApiKey}; - use dropshot_authorization_header::bearer::BearerAuth; - use uuid::Uuid; +// #[cfg(test)] +// mod tests { +// use super::{ApiKey, NewApiKey}; +// use dropshot_authorization_header::bearer::BearerAuth; +// use uuid::Uuid; - #[test] - fn test_decodes_token() { - let id = Uuid::new_v4(); - let (token, hash) = NewApiKey::generate::<24>(&id).consume(); +// #[test] +// fn test_decodes_token() { +// let id = Uuid::new_v4(); +// let (token, hash) = NewApiKey::generate::<24>(&id).consume(); - let bearer = BearerAuth::new(token); +// let bearer = BearerAuth::new(token); - let authn: ApiKey = bearer.key().unwrap().try_into().unwrap(); +// let authn: ApiKey = bearer.key().unwrap().try_into().unwrap(); - let verified = authn.verify(&hash); +// let verified = authn.verify(&hash); - assert!(verified); - } +// assert!(verified); +// } - #[test] - fn test_fails_to_decode_invalid_token() { - let id = Uuid::new_v4(); - let (token1, _hash1) = NewApiKey::generate::<24>(&id).consume(); - let (_token2, hash2) = NewApiKey::generate::<24>(&id).consume(); +// #[test] +// fn test_fails_to_decode_invalid_token() { +// let id = Uuid::new_v4(); +// let (token1, _hash1) = NewApiKey::generate::<24>(&id).consume(); +// let (_token2, hash2) = NewApiKey::generate::<24>(&id).consume(); - let bearer = BearerAuth::new(token1); +// let bearer = BearerAuth::new(token1); - let authn: ApiKey = bearer.key().unwrap().try_into().unwrap(); +// let authn: ApiKey = bearer.key().unwrap().try_into().unwrap(); - let verified = authn.verify(&hash2); +// let verified = authn.verify(&hash2); - assert!(!verified); - } -} +// assert!(!verified); +// } +// } diff --git a/rfd-api/src/authn/mod.rs b/rfd-api/src/authn/mod.rs index 2a8e208..bc794e0 100644 --- a/rfd-api/src/authn/mod.rs +++ b/rfd-api/src/authn/mod.rs @@ -6,7 +6,7 @@ use thiserror::Error; use crate::{context::ApiContext, util::response::unauthorized}; -use self::{jwt::Jwt, key::ApiKey}; +use self::{jwt::Jwt, key::EncryptedApiKey}; pub mod jwt; pub mod key; @@ -21,7 +21,7 @@ pub enum AuthError { // A token that provides authentication and optionally (JWT) authorization claims pub enum AuthToken { - ApiKey(ApiKey), + ApiKey(EncryptedApiKey), Jwt(Jwt), } @@ -47,18 +47,19 @@ impl AuthToken { // Attempt to decode an API key from the token. If that fails then attempt to verify the // token as a JWT - match ApiKey::try_from(token.as_str()) { - Ok(api_key) => Ok(AuthToken::ApiKey(api_key)), + let jwt = Jwt::new(ctx, &token).await; + + match jwt { + Ok(token) => Ok(AuthToken::Jwt(token)), Err(err) => { - tracing::trace!(?err, "Bearer token is not an api key"); + tracing::debug!(?err, "Token is not a JWT, falling back to API key"); - Jwt::new(ctx, &token) - .await - .map(AuthToken::Jwt) - .map_err(|err| { - tracing::trace!(?err, "Bearer token is not a valid JWT"); + Ok(AuthToken::ApiKey(EncryptedApiKey { + encrypted: ctx.encrypt(token.as_str()).await.map_err(|err| { + tracing::error!(?err, "Failed to encrypt authn token"); AuthError::FailedToExtract - }) + })?, + })) } } } diff --git a/rfd-api/src/config.rs b/rfd-api/src/config.rs index 3d783dc..11d17f7 100644 --- a/rfd-api/src/config.rs +++ b/rfd-api/src/config.rs @@ -70,12 +70,12 @@ pub struct PermissionsConfig { pub struct JwtConfig { pub default_expiration: i64, pub max_expiration: i64, - pub keys: Vec, + pub keys: Vec, } #[derive(Debug, Deserialize)] #[serde(tag = "kind", rename_all = "lowercase")] -pub enum JwtKey { +pub enum AsymmetricKey { Local { kid: String, #[serde(with = "serde_bytes")] diff --git a/rfd-api/src/context.rs b/rfd-api/src/context.rs index 6a27067..6e08c60 100644 --- a/rfd-api/src/context.rs +++ b/rfd-api/src/context.rs @@ -4,15 +4,18 @@ use http::StatusCode; use hyper::{client::HttpConnector, Body, Client}; use hyper_tls::HttpsConnector; use jsonwebtoken::jwk::JwkSet; +use oauth2::CsrfToken; use rfd_model::{ permissions::{Caller, Permissions}, + schema_ext::LoginAttemptState, storage::{ - AccessTokenStore, ApiUserFilter, ApiUserProviderFilter, ApiUserProviderStore, ApiUserStore, - ApiUserTokenFilter, ApiUserTokenStore, JobStore, ListPagination, RfdFilter, RfdPdfFilter, - RfdPdfStore, RfdRevisionFilter, RfdRevisionStore, RfdStore, StoreError, + AccessTokenStore, ApiKeyFilter, ApiKeyStore, ApiUserFilter, ApiUserProviderFilter, + ApiUserProviderStore, ApiUserStore, JobStore, ListPagination, LoginAttemptFilter, + LoginAttemptStore, RfdFilter, RfdPdfFilter, RfdPdfStore, RfdRevisionFilter, + RfdRevisionStore, RfdStore, StoreError, }, - AccessToken, ApiUser, ApiUserProvider, Job, NewAccessToken, NewApiUser, NewApiUserProvider, - NewApiUserToken, NewJob, + AccessToken, ApiUser, ApiUserProvider, InvalidValueError, Job, LoginAttempt, NewAccessToken, + NewApiKey, NewApiUser, NewApiUserProvider, NewJob, NewLoginAttempt, }; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; @@ -25,15 +28,12 @@ use uuid::Uuid; use crate::{ authn::{ jwt::{key_to_signer, Claims, JwtSigner, SignerError}, + key::{key_to_encryptor, EncryptorError, KeyEncryptor}, AuthError, AuthToken, }, config::{JwtConfig, PermissionsConfig}, email_validator::EmailValidator, endpoints::login::{ - access_token::{ - github::GitHubAccessTokenIdentity, AccessTokenProvider, AccessTokenProviderName, - }, - jwt::{google::GoogleOidcIdentity, JwksProvider, JwtProvider, JwtProviderName}, oauth::{OAuthProvider, OAuthProviderError, OAuthProviderFn, OAuthProviderName}, LoginError, UserInfo, }, @@ -49,9 +49,10 @@ pub trait Storage: + RfdPdfStore + JobStore + ApiUserStore - + ApiUserTokenStore + + ApiKeyStore + ApiUserProviderStore + AccessTokenStore + + LoginAttemptStore + Send + Sync + 'static @@ -63,9 +64,10 @@ impl Storage for T where + RfdPdfStore + JobStore + ApiUserStore - + ApiUserTokenStore + + ApiKeyStore + ApiUserProviderStore + AccessTokenStore + + LoginAttemptStore + Send + Sync + 'static @@ -79,8 +81,8 @@ pub struct ApiContext { pub storage: Arc, pub permissions: PermissionsContext, pub jwt: JwtContext, + pub api_key: ApiKeyContext, pub oauth_providers: HashMap>, - pub jwks_providers: HashMap>, } pub struct PermissionsContext { @@ -94,6 +96,10 @@ pub struct JwtContext { pub jwks: JwkSet, } +pub struct ApiKeyContext { + pub encryptor: Box, +} + pub struct RegisteredAccessToken { pub access_token: AccessToken, pub signed_token: String, @@ -121,6 +127,14 @@ impl From for HttpError { } } +#[derive(Debug, Error)] +pub enum LoginAttemptError { + #[error(transparent)] + FailedToCreate(#[from] InvalidValueError), + #[error(transparent)] + Storage(#[from] StoreError), +} + #[derive(Debug, Clone, Deserialize, Serialize, JsonSchema)] pub struct FullRfd { pub id: Uuid, @@ -153,8 +167,8 @@ impl ApiContext { ) -> Result { let mut keys = vec![]; - for key in jwt.keys { - keys.push(key_to_signer(&key).await?); + for key in &jwt.keys { + keys.push(key_to_signer(key).await?); } Ok(Self { @@ -173,8 +187,10 @@ impl ApiContext { }, keys, }, + api_key: ApiKeyContext { + encryptor: key_to_encryptor(&jwt.keys[0]).await?, + }, oauth_providers: HashMap::new(), - jwks_providers: HashMap::new(), }) } @@ -199,34 +215,42 @@ impl ApiContext { signer.sign(claims).await } + pub async fn encrypt(&self, value: &str) -> Result { + self.api_key.encryptor.encrypt(value).await + } + #[instrument(skip(self, auth))] pub async fn get_caller(&self, auth: &AuthToken) -> Result { let (api_user_id, permissions) = match auth { AuthToken::ApiKey(api_key) => { async { - let token_id = api_key.id(); - tracing::debug!("Attempt to authenticate"); - let key = ApiUserTokenStore::get(&*self.storage, &token_id, false).await?; - - if let Some(key) = key { - tracing::debug!("Found key to test"); + let mut key = ApiKeyStore::list( + &*self.storage, + ApiKeyFilter { + key: Some(vec![api_key.encrypted.to_string()]), + expired: false, + deleted: false, + ..Default::default() + }, + &ListPagination { + offset: 0, + limit: 1, + }, + ) + .await?; - if api_key.verify(&key.token) { - tracing::debug!("Verified caller key"); + if let Some(key) = key.pop() { + tracing::debug!("Verified caller key"); - Ok((key.api_user_id, key.permissions)) - } else { - tracing::debug!("Failed to verify token"); - Err(CallerError::FailedToAuthenticate) - } + Ok((key.api_user_id, key.permissions)) } else { - tracing::debug!("Failed to find matching token"); + tracing::debug!("Failed to find matching key"); Err(CallerError::FailedToAuthenticate) } } - .instrument(info_span!("Test api key", key_id = ?api_key.id())) + .instrument(info_span!("Test api key")) .await } AuthToken::Jwt(jwt) => { @@ -258,45 +282,15 @@ impl ApiContext { let users = ApiUserStore::list(&*self.storage, user_filter, &ListPagination::latest()).await?; - let mut token_filter = ApiUserTokenFilter::default(); + let mut token_filter = ApiKeyFilter::default(); token_filter.deleted = true; let tokens = - ApiUserTokenStore::list(&*self.storage, token_filter, &ListPagination::latest()) - .await?; + ApiKeyStore::list(&*self.storage, token_filter, &ListPagination::latest()).await?; Ok(users.len() == 0 && tokens.len() == 0) } - pub async fn get_access_token_provider( - &self, - provider: &AccessTokenProviderName, - token: String, - ) -> Result, LoginError> { - match provider { - AccessTokenProviderName::GitHub => Ok(Box::new(GitHubAccessTokenIdentity::new(token)?)), - } - } - - pub fn insert_jwks_provider(&mut self, name: JwtProviderName, provider: Box) { - self.jwks_providers.insert(name, provider); - } - - pub async fn get_jwt_identity<'a>( - &'a self, - provider: &JwtProviderName, - token: String, - ) -> Result, LoginError> { - match provider { - JwtProviderName::Google => Ok(Box::new(GoogleOidcIdentity::new( - token, - self.jwks_providers - .get(provider) - .ok_or(LoginError::InvalidProvider)?, - ))), - } - } - pub fn insert_oauth_provider( &mut self, name: OAuthProviderName, @@ -520,14 +514,14 @@ impl ApiContext { pub async fn create_api_user_token( &self, - token: NewApiUserToken, + token: NewApiKey, api_user: &ApiUser, ) -> Result { - ApiUserTokenStore::upsert(&*self.storage, token, api_user).await + ApiKeyStore::upsert(&*self.storage, token, api_user).await } pub async fn get_api_user_token(&self, id: &Uuid) -> Result, StoreError> { - ApiUserTokenStore::get(&*self.storage, id, false).await + ApiKeyStore::get(&*self.storage, id, false).await } pub async fn get_api_user_tokens( @@ -535,12 +529,13 @@ impl ApiContext { api_user_id: &Uuid, pagination: &ListPagination, ) -> Result, StoreError> { - ApiUserTokenStore::list( + ApiKeyStore::list( &*self.storage, - ApiUserTokenFilter { + ApiKeyFilter { api_user_id: Some(vec![*api_user_id]), expired: true, deleted: false, + ..Default::default() }, pagination, ) @@ -570,14 +565,72 @@ impl ApiContext { } pub async fn delete_api_user_token(&self, id: &Uuid) -> Result, StoreError> { - ApiUserTokenStore::delete(&*self.storage, id).await + ApiKeyStore::delete(&*self.storage, id).await } pub async fn create_access_token( &self, - device_token: NewAccessToken, + access_token: NewAccessToken, ) -> Result { - AccessTokenStore::upsert(&*self.storage, device_token).await + AccessTokenStore::upsert(&*self.storage, access_token).await + } + + pub async fn create_login_attempt( + &self, + attempt: NewLoginAttempt, + ) -> Result { + LoginAttemptStore::upsert(&*self.storage, attempt).await + } + + pub async fn set_login_provider_authz_code( + &self, + attempt: LoginAttempt, + code: String, + ) -> Result { + let mut attempt: NewLoginAttempt = attempt.into(); + attempt.provider_authz_code = Some(code); + + // TODO: Internal state changes to the struct + attempt.attempt_state = LoginAttemptState::RemoteAuthenticated; + attempt.authz_code = Some(CsrfToken::new_random().secret().to_string()); + + LoginAttemptStore::upsert(&*self.storage, attempt).await + } + + pub async fn get_login_attempt(&self, id: &Uuid) -> Result, StoreError> { + LoginAttemptStore::get(&*self.storage, id).await + } + + pub async fn get_login_attempt_for_code( + &self, + code: &str, + ) -> Result, StoreError> { + let filter = LoginAttemptFilter { + attempt_state: Some(vec![LoginAttemptState::RemoteAuthenticated]), + authz_code: Some(vec![code.to_string()]), + ..Default::default() + }; + + let mut attempts = LoginAttemptStore::list( + &*self.storage, + filter, + &ListPagination { + offset: 0, + limit: 1, + }, + ) + .await?; + + Ok(attempts.pop()) + } + + pub async fn fail_login_attempt( + &self, + attempt: LoginAttempt, + ) -> Result { + let mut attempt: NewLoginAttempt = attempt.into(); + attempt.attempt_state = LoginAttemptState::Failed; + LoginAttemptStore::upsert(&*self.storage, attempt).await } } @@ -587,13 +640,14 @@ pub(crate) mod tests { use rfd_model::{ permissions::Caller, storage::{ - AccessTokenStore, ApiUserProviderStore, ApiUserStore, ApiUserTokenStore, JobStore, - ListPagination, MockAccessTokenStore, MockApiUserProviderStore, MockApiUserStore, - MockApiUserTokenStore, MockJobStore, MockRfdPdfStore, MockRfdRevisionStore, - MockRfdStore, RfdPdfStore, RfdRevisionStore, RfdStore, + AccessTokenStore, ApiKeyStore, ApiUserProviderStore, ApiUserStore, JobStore, + ListPagination, LoginAttemptStore, MockAccessTokenStore, MockApiKeyStore, + MockApiUserProviderStore, MockApiUserStore, MockJobStore, MockLoginAttemptStore, + MockRfdPdfStore, MockRfdRevisionStore, MockRfdStore, RfdPdfStore, RfdRevisionStore, + RfdStore, }, - ApiUserProvider, ApiUserToken, NewAccessToken, NewApiUser, NewApiUserProvider, - NewApiUserToken, NewJob, NewRfd, NewRfdPdf, NewRfdRevision, + ApiKey, ApiUserProvider, NewAccessToken, NewApiKey, NewApiUser, NewApiUserProvider, NewJob, + NewLoginAttempt, NewRfd, NewRfdPdf, NewRfdRevision, }; use std::sync::Arc; @@ -607,9 +661,10 @@ pub(crate) mod tests { pub rfd_pdf_store: Option>, pub job_store: Option>, pub api_user_store: Option>>, - pub api_user_token_store: Option>>, + pub api_user_token_store: Option>>, pub api_user_provider_store: Option>, pub device_token_store: Option>, + pub login_attempt_store: Option>, } impl MockStorage { @@ -624,6 +679,7 @@ pub(crate) mod tests { api_user_token_store: None, api_user_provider_store: None, device_token_store: None, + login_attempt_store: None, } } } @@ -824,12 +880,12 @@ pub(crate) mod tests { } #[async_trait] - impl ApiUserTokenStore for MockStorage { + impl ApiKeyStore for MockStorage { async fn get( &self, id: &uuid::Uuid, deleted: bool, - ) -> Result>, rfd_model::storage::StoreError> { + ) -> Result>, rfd_model::storage::StoreError> { self.api_user_token_store .as_ref() .unwrap() @@ -839,9 +895,9 @@ pub(crate) mod tests { async fn list( &self, - filter: rfd_model::storage::ApiUserTokenFilter, + filter: rfd_model::storage::ApiKeyFilter, pagination: &ListPagination, - ) -> Result>, rfd_model::storage::StoreError> { + ) -> Result>, rfd_model::storage::StoreError> { self.api_user_token_store .as_ref() .unwrap() @@ -851,9 +907,9 @@ pub(crate) mod tests { async fn upsert( &self, - token: NewApiUserToken, + token: NewApiKey, api_user: &rfd_model::ApiUser, - ) -> Result, rfd_model::storage::StoreError> { + ) -> Result, rfd_model::storage::StoreError> { self.api_user_token_store .as_ref() .unwrap() @@ -864,7 +920,7 @@ pub(crate) mod tests { async fn delete( &self, id: &uuid::Uuid, - ) -> Result>, rfd_model::storage::StoreError> { + ) -> Result>, rfd_model::storage::StoreError> { self.api_user_token_store.as_ref().unwrap().delete(id).await } } @@ -955,4 +1011,37 @@ pub(crate) mod tests { .await } } + + #[async_trait] + impl LoginAttemptStore for MockStorage { + async fn get( + &self, + id: &uuid::Uuid, + ) -> Result, rfd_model::storage::StoreError> { + self.login_attempt_store.as_ref().unwrap().get(id).await + } + + async fn list( + &self, + filter: rfd_model::storage::LoginAttemptFilter, + pagination: &ListPagination, + ) -> Result, rfd_model::storage::StoreError> { + self.login_attempt_store + .as_ref() + .unwrap() + .list(filter, pagination) + .await + } + + async fn upsert( + &self, + attempt: NewLoginAttempt, + ) -> Result { + self.login_attempt_store + .as_ref() + .unwrap() + .upsert(attempt) + .await + } + } } diff --git a/rfd-api/src/endpoints/api_user.rs b/rfd-api/src/endpoints/api_user.rs index 888be8b..f54b1b4 100644 --- a/rfd-api/src/endpoints/api_user.rs +++ b/rfd-api/src/endpoints/api_user.rs @@ -4,7 +4,7 @@ use dropshot::{ }; use http::StatusCode; use partial_struct::partial; -use rfd_model::{storage::ListPagination, ApiUser, NewApiUser, NewApiUserToken}; +use rfd_model::{storage::ListPagination, ApiUser, NewApiKey, NewApiUser}; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; use trace_request::trace_request; @@ -12,11 +12,11 @@ use tracing::instrument; use uuid::Uuid; use crate::{ - authn::key::NewApiKey, + authn::key::RawApiKey, context::ApiContext, error::ApiError, permissions::ApiPermission, - util::response::{client_error, not_found}, + util::response::{client_error, not_found, to_internal_error}, ApiCaller, ApiPermissions, User, }; @@ -185,7 +185,7 @@ async fn update_api_user_op( pub async fn list_api_user_tokens( rqctx: RequestContext, path: Path, -) -> Result>, HttpError> { +) -> Result>, HttpError> { let ctx = rqctx.context(); let auth = ctx.authn_token(&rqctx).await?; let caller = ctx.get_caller(&auth).await?; @@ -197,7 +197,7 @@ async fn list_api_user_tokens_op( ctx: &ApiContext, caller: &ApiCaller, path: &ApiUserPath, -) -> Result>, HttpError> { +) -> Result>, HttpError> { if caller.can(&ApiPermission::GetApiUserToken(path.identifier).into()) { tracing::info!("Fetch token list"); @@ -211,7 +211,7 @@ async fn list_api_user_tokens_op( Ok(HttpResponseOk( tokens .into_iter() - .map(|token| ApiUserTokenResponse { + .map(|token| ApiKeyResponse { id: token.id, permissions: token.permissions, created_at: token.created_at, @@ -224,17 +224,17 @@ async fn list_api_user_tokens_op( } #[derive(Debug, Clone, Deserialize, JsonSchema)] -pub struct ApiUserTokenCreateParams { +pub struct ApiKeyCreateParams { permissions: ApiPermissions, expires_at: DateTime, } -#[partial(ApiUserTokenResponse)] +#[partial(ApiKeyResponse)] #[derive(Debug, Serialize, JsonSchema)] -pub struct InitialApiUserTokenResponse { +pub struct InitialApiKeyResponse { pub id: Uuid, - #[partial(ApiUserTokenResponse(skip))] - pub token: String, + #[partial(ApiKeyResponse(skip))] + pub key: String, pub permissions: ApiPermissions, pub created_at: DateTime, } @@ -250,8 +250,8 @@ pub struct InitialApiUserTokenResponse { pub async fn create_api_user_token( rqctx: RequestContext, path: Path, - body: TypedBody, -) -> Result, HttpError> { + body: TypedBody, +) -> Result, HttpError> { let ctx = rqctx.context(); let auth = ctx.authn_token(&rqctx).await?; let caller = ctx.get_caller(&auth).await?; @@ -263,8 +263,8 @@ async fn create_api_user_token_op( ctx: &ApiContext, caller: &ApiCaller, path: &ApiUserPath, - body: ApiUserTokenCreateParams, -) -> Result, HttpError> { + body: ApiKeyCreateParams, +) -> Result, HttpError> { if caller.can(&ApiPermission::CreateApiUserToken(path.identifier).into()) { let api_user = ctx .get_api_user(&path.identifier) @@ -272,16 +272,20 @@ async fn create_api_user_token_op( .map_err(ApiError::Storage)?; if let Some(api_user) = api_user { - let token_id = Uuid::new_v4(); + let key_id = Uuid::new_v4(); - let (token, hash) = NewApiKey::generate::<24>(&token_id).consume(); + let key = RawApiKey::generate::<24>(); + let encrypted = key + .encrypt(&*ctx.api_key.encryptor) + .await + .map_err(to_internal_error)?; - let user_token = ctx + let user_key = ctx .create_api_user_token( - NewApiUserToken { - id: token_id, + NewApiKey { + id: key_id, api_user_id: path.identifier, - token: hash, + key: encrypted.encrypted, permissions: body.permissions, expires_at: body.expires_at, }, @@ -292,11 +296,11 @@ async fn create_api_user_token_op( // Creating an api token will return the hashed version, but we need to return the // plaintext token as we do not store a copy - Ok(HttpResponseCreated(InitialApiUserTokenResponse { - id: user_token.id, - token, - permissions: user_token.permissions, - created_at: user_token.created_at, + Ok(HttpResponseCreated(InitialApiKeyResponse { + id: user_key.id, + key: key.consume(), + permissions: user_key.permissions, + created_at: user_key.created_at, })) } else { Err(not_found("Failed to find api user")) @@ -322,7 +326,7 @@ pub struct ApiUserTokenPath { pub async fn get_api_user_token( rqctx: RequestContext, path: Path, -) -> Result, HttpError> { +) -> Result, HttpError> { let ctx = rqctx.context(); let auth = ctx.authn_token(&rqctx).await?; let caller = ctx.get_caller(&auth).await?; @@ -334,7 +338,7 @@ async fn get_api_user_token_op( ctx: &ApiContext, caller: &ApiCaller, path: &ApiUserTokenPath, -) -> Result, HttpError> { +) -> Result, HttpError> { if caller.can(&ApiPermission::GetApiUserToken(path.identifier).into()) { let token = ctx .get_api_user_token(&path.token_identifier) @@ -342,7 +346,7 @@ async fn get_api_user_token_op( .map_err(ApiError::Storage)?; if let Some(token) = token { - Ok(HttpResponseOk(ApiUserTokenResponse { + Ok(HttpResponseOk(ApiKeyResponse { id: token.id, permissions: token.permissions, created_at: token.created_at, @@ -365,7 +369,7 @@ async fn get_api_user_token_op( pub async fn delete_api_user_token( rqctx: RequestContext, path: Path, -) -> Result, HttpError> { +) -> Result, HttpError> { let ctx = rqctx.context(); let auth = ctx.authn_token(&rqctx).await?; let caller = ctx.get_caller(&auth).await?; @@ -377,7 +381,7 @@ async fn delete_api_user_token_op( ctx: &ApiContext, caller: &ApiCaller, path: &ApiUserTokenPath, -) -> Result, HttpError> { +) -> Result, HttpError> { if caller.can(&ApiPermission::DeleteApiUserToken(path.identifier).into()) { let token = ctx .delete_api_user_token(&path.token_identifier) @@ -385,7 +389,7 @@ async fn delete_api_user_token_op( .map_err(ApiError::Storage)?; if let Some(token) = token { - Ok(HttpResponseOk(ApiUserTokenResponse { + Ok(HttpResponseOk(ApiKeyResponse { id: token.id, permissions: token.permissions, created_at: token.created_at, @@ -406,10 +410,8 @@ mod tests { use http::StatusCode; use mockall::predicate::eq; use rfd_model::{ - storage::{ - ApiUserTokenFilter, ListPagination, MockApiUserStore, MockApiUserTokenStore, StoreError, - }, - ApiUser, ApiUserToken, NewApiUser, + storage::{ApiKeyFilter, ListPagination, MockApiKeyStore, MockApiUserStore, StoreError}, + ApiKey, ApiUser, NewApiUser, }; use uuid::Uuid; @@ -418,7 +420,7 @@ mod tests { context::{tests::MockStorage, ApiContext}, endpoints::api_user::{ create_api_user_token_op, delete_api_user_token_op, get_api_user_token_op, - list_api_user_tokens_op, update_api_user_op, ApiUserPath, ApiUserTokenCreateParams, + list_api_user_tokens_op, update_api_user_op, ApiKeyCreateParams, ApiUserPath, ApiUserTokenPath, }, permissions::ApiPermission, @@ -621,10 +623,10 @@ mod tests { let success_id = Uuid::new_v4(); let failure_id = Uuid::new_v4(); - let mut store = MockApiUserTokenStore::new(); + let mut store = MockApiKeyStore::new(); store .expect_list() - .withf(move |x: &ApiUserTokenFilter, _: &ListPagination| { + .withf(move |x: &ApiKeyFilter, _: &ListPagination| { x.api_user_id .as_ref() .map(|id| id.contains(&success_id)) @@ -633,7 +635,7 @@ mod tests { .returning(|_, _| Ok(vec![])); store .expect_list() - .withf(move |x: &ApiUserTokenFilter, _: &ListPagination| { + .withf(move |x: &ApiKeyFilter, _: &ListPagination| { x.api_user_id .as_ref() .map(|id| id.contains(&failure_id)) @@ -751,7 +753,7 @@ mod tests { identifier: Uuid::new_v4(), }; - let new_token = ApiUserTokenCreateParams { + let new_token = ApiKeyCreateParams { permissions: Vec::new().into(), expires_at: Utc::now() + Duration::seconds(5 * 60), }; @@ -770,17 +772,17 @@ mod tests { .with(eq(unknown_api_user_path.identifier), eq(false)) .returning(move |_, _| Ok(None)); - let mut token_store = MockApiUserTokenStore::new(); + let mut token_store = MockApiKeyStore::new(); token_store .expect_upsert() .withf(move |_, user| user.id == api_user_id) - .returning(|token, user| { - Ok(ApiUserToken { + .returning(|key, user| { + Ok(ApiKey { id: Uuid::new_v4(), api_user_id: user.id, - token: token.token, - permissions: token.permissions, - expires_at: token.expires_at, + key: key.key, + permissions: key.permissions, + expires_at: key.expires_at, created_at: Utc::now(), updated_at: Utc::now(), deleted_at: None, @@ -895,10 +897,10 @@ mod tests { async fn test_get_api_user_token_permissions() { let api_user_id = Uuid::new_v4(); - let token = ApiUserToken { + let token = ApiKey { id: Uuid::new_v4(), api_user_id: api_user_id, - token: "hashed_token".to_string(), + key: "encrypted_key".to_string(), permissions: Vec::new().into(), expires_at: Utc::now() + Duration::seconds(5 * 60), created_at: Utc::now(), @@ -921,7 +923,7 @@ mod tests { token_identifier: Uuid::new_v4(), }; - let mut token_store = MockApiUserTokenStore::new(); + let mut token_store = MockApiKeyStore::new(); token_store .expect_get() .with(eq(api_user_token_path.token_identifier), eq(false)) @@ -1021,10 +1023,10 @@ mod tests { async fn test_delete_api_user_token_permissions() { let api_user_id = Uuid::new_v4(); - let token = ApiUserToken { + let token = ApiKey { id: Uuid::new_v4(), api_user_id: api_user_id, - token: "hashed_token".to_string(), + key: "encrypted_key".to_string(), permissions: Vec::new().into(), expires_at: Utc::now() + Duration::seconds(5 * 60), created_at: Utc::now(), @@ -1047,7 +1049,7 @@ mod tests { token_identifier: Uuid::new_v4(), }; - let mut token_store = MockApiUserTokenStore::new(); + let mut token_store = MockApiKeyStore::new(); token_store .expect_delete() .with(eq(api_user_token_path.token_identifier)) diff --git a/rfd-api/src/endpoints/login/mod.rs b/rfd-api/src/endpoints/login/mod.rs index 7934d25..767f5f7 100644 --- a/rfd-api/src/endpoints/login/mod.rs +++ b/rfd-api/src/endpoints/login/mod.rs @@ -13,10 +13,10 @@ use crate::{ util::response::{bad_request, internal_error}, }; -use self::access_token::AccessTokenError; +// use self::access_token::AccessTokenError; -pub mod access_token; -pub mod jwt; +// pub mod access_token; +// pub mod jwt; pub mod oauth; #[derive(Debug, Deserialize, Serialize, JsonSchema)] @@ -27,8 +27,8 @@ pub enum LoginPermissions { #[derive(Debug, Error)] pub enum LoginError { - #[error(transparent)] - AccessTokenError(#[from] AccessTokenError), + // #[error(transparent)] + // AccessTokenError(#[from] AccessTokenError), #[error("Requested token lifetime exceeds maximum configuration duration")] ExcessTokenExpiration, #[error("Failed to parse access token {0}")] @@ -42,9 +42,9 @@ pub enum LoginError { impl From for HttpError { fn from(err: LoginError) -> Self { match err { - LoginError::AccessTokenError(_) => { - internal_error("Failed to construct internal client to authenticate") - } + // LoginError::AccessTokenError(_) => { + // internal_error("Failed to construct internal client to authenticate") + // } LoginError::ExcessTokenExpiration => { let mut err = bad_request("Requested expiration exceeds maximum allowed token duration"); diff --git a/rfd-api/src/endpoints/login/oauth/authz_code.rs b/rfd-api/src/endpoints/login/oauth/authz_code.rs new file mode 100644 index 0000000..e8dbd1c --- /dev/null +++ b/rfd-api/src/endpoints/login/oauth/authz_code.rs @@ -0,0 +1,272 @@ +use chrono::{Duration, Utc}; +use dropshot::{ + endpoint, http_response_temporary_redirect, HttpError, HttpResponseOk, + HttpResponseTemporaryRedirect, Path, Query, RequestContext, +}; +use oauth2::{ + reqwest::async_http_client, AuthorizationCode, CsrfToken, PkceCodeChallenge, PkceCodeVerifier, + Scope, TokenResponse, +}; +use rfd_model::{schema_ext::LoginAttemptState, NewLoginAttempt}; +use schemars::JsonSchema; +use serde::{Deserialize, Serialize}; +use std::fmt::Debug; +use std::ops::Add; +use tap::TapFallible; +use tracing::instrument; +use uuid::Uuid; + +use super::{OAuthProviderNameParam, UserInfoProvider}; +use crate::{ + context::ApiContext, + endpoints::login::LoginError, + error::ApiError, + util::response::{bad_request, internal_error, to_internal_error}, +}; + +#[derive(Debug, Deserialize, JsonSchema, Serialize)] +pub struct OAuthAuthzCodeQuery { + pub client_id: Uuid, + pub redirect_uri: String, + pub response_type: String, + pub state: String, +} + +/// Generate the remote provider login url and redirect the user +#[endpoint { + method = GET, + path = "/login/oauth/{provider}/authz_code/authorize" +}] +#[instrument(skip(rqctx), fields(request_id = rqctx.request_id), err(Debug))] +pub async fn authz_code_redirect( + rqctx: RequestContext, + path: Path, + query: Query, +) -> Result { + let ctx = rqctx.context(); + let path = path.into_inner(); + let query = query.into_inner(); + let provider = ctx + .get_oauth_provider(&path.provider) + .await + .map_err(ApiError::OAuth)?; + + tracing::debug!(provider = ?provider.name(), "Acquired OAuth provider for authz code login"); + + let client = provider.as_client().map_err(to_internal_error)?; + let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256(); + + // Construct a new login attempt with the minimum required values + let mut attempt = NewLoginAttempt::new( + query.client_id, + query.redirect_uri, + provider.name().to_string(), + CsrfToken::new_random().secret().to_string(), + pkce_verifier.secret().to_string(), + ) + .map_err(|err| { + tracing::error!(?err, "Attempted to construct invalid login attempt"); + internal_error("Attempted to construct invalid login attempt".to_string()) + })?; + + // Add in the user defined state and redirect uri + attempt.state = Some(query.state); + + // Store the generate attempt + let attempt = ctx + .create_login_attempt(attempt) + .await + .map_err(to_internal_error)?; + + // Generate the url to the remote provider that the user will be redirected to + let (url, _) = client + .authorize_url(|| CsrfToken::new(format!("{}:{}", attempt.id, attempt.provider_state))) + .set_pkce_challenge(pkce_challenge) + .add_scopes( + provider + .scopes() + .into_iter() + .map(|s| Scope::new(s.to_string())) + .collect::>(), + ) + .url(); + + http_response_temporary_redirect(url.to_string()) +} + +#[derive(Debug, Deserialize, JsonSchema, Serialize)] +pub struct OAuthAuthzCodeReturnQuery { + pub state: String, + pub code: String, +} + +/// Handle return calls from a remote OAuth provider +#[endpoint { + method = GET, + path = "/login/oauth/{provider}/authz_code/return" +}] +#[instrument(skip(rqctx), fields(request_id = rqctx.request_id), err(Debug))] +pub async fn authz_code_return( + rqctx: RequestContext, + path: Path, + query: Query, +) -> Result { + let ctx = rqctx.context(); + let path = path.into_inner(); + let provider = ctx + .get_oauth_provider(&path.provider) + .await + .map_err(ApiError::OAuth)?; + let query = query.into_inner(); + + tracing::debug!(provider = ?provider.name(), "Acquired OAuth provider for authz code exchange"); + + // Attempt to extract the request id and csrf token from the state parameter. These must both + // be present + let (id, csrf) = query + .state + .split_once(":") + .and_then(|(id, csrf)| id.parse::().ok().map(|id| (id, csrf))) + .ok_or_else(|| bad_request("Invalid state".to_string()))?; + + // Look up the login attempt referenced in the state and verify that has the csrf value still + // matches + let original_attempt = ctx + .get_login_attempt(&id) + .await + .map_err(to_internal_error)? + .ok_or_else(|| bad_request("Invalid login attempt".to_string()))?; + + // Verify the csrf token. If these do not match, then fail the request + if original_attempt.provider_state != csrf { + ctx.fail_login_attempt(original_attempt) + .await + .map_err(to_internal_error)?; + return Err(bad_request("Invalid csrf".to_string())); + } + + // Store the authorization code returned by the underlying OAuth provider and transition the + // attempt to the awaiting state + let attempt = ctx + .set_login_provider_authz_code(original_attempt, query.code) + .await + .map_err(to_internal_error)?; + + // Redirect back to the original authenticator + http_response_temporary_redirect(attempt.callback_url()) +} + +#[derive(Debug, Deserialize, JsonSchema, Serialize)] +pub struct OAuthAuthzCodeExchangeQuery { + pub client_id: Uuid, + pub redirect_uri: String, + pub grant_type: String, + pub code: String, + pub pkce_verifier: Option, +} + +#[derive(Debug, Deserialize, JsonSchema, Serialize)] +pub struct OAuthAuthzCodeExchangeResponse { + pub access_token: String, + pub token_type: String, + pub expires_in: i64, +} + +/// Exchange an authorization code for an access token +#[endpoint { + method = GET, + path = "/login/oauth/{provider}/authz_code/exchange" +}] +#[instrument(skip(rqctx), fields(request_id = rqctx.request_id), err(Debug))] +pub async fn authz_code_exchange( + rqctx: RequestContext, + path: Path, + query: Query, +) -> Result, HttpError> { + let ctx = rqctx.context(); + let path = path.into_inner(); + let provider = ctx + .get_oauth_provider(&path.provider) + .await + .map_err(ApiError::OAuth)?; + let query = query.into_inner(); + + // Verify that we received the expected grant type + if &query.grant_type != "authorization_code" { + return Err(bad_request("Invalid grant type")); + } + + // Lookup the request assigned to this code and verify that it is a valid request + let attempt = ctx + .get_login_attempt_for_code(&query.code) + .await + .map_err(to_internal_error)? + .ok_or_else(|| bad_request("Invalid code".to_string())) + .and_then(|attempt| { + if attempt.client_id != query.client_id { + Err(bad_request("Invalid client id".to_string())) + } else if attempt.redirect_uri != query.redirect_uri { + Err(bad_request("Invalid redirect uri".to_string())) + } else if attempt.attempt_state != LoginAttemptState::RemoteAuthenticated { + Err(bad_request("Invalid login state".to_string())) + } else if attempt.expires_at.map(|t| t <= Utc::now()).unwrap_or(true) { + Err(bad_request("Login attempt expired".to_string())) + } else { + // TODO: Perform pkce check + + Ok(attempt) + } + })?; + + // Exchange the stored authorization code with the remote provider for a remote access token + let client = provider.as_client().map_err(to_internal_error)?; + let response = client + .exchange_code(AuthorizationCode::new( + attempt.provider_authz_code.ok_or_else(|| { + internal_error("Expected authorization code to exist due to attempt state") + })?, + )) + .set_pkce_verifier(PkceCodeVerifier::new(attempt.provider_pkce_verifier)) + .request_async(async_http_client) + .await + .map_err(to_internal_error)?; + + // Use the retrieved access token to fetch the user information from the remote API + let info = provider + .get_user_info(&ctx.https_client, response.access_token().secret()) + .await + .map_err(LoginError::UserInfo) + .tap_err(|err| tracing::error!(?err, "Failed to look up user information"))?; + + tracing::debug!("Verified and validated OAuth user"); + + // Register this user as an API user if needed + let api_user = ctx.register_api_user(info).await?; + + tracing::info!(api_user_id = ?api_user.id, "Retrieved api user to generate access token for"); + + // Generate a new access token for the user with an expiration matching the value given to us + // by the remote service + let token = ctx + .register_access_token( + &api_user, + &api_user.permissions, + Some( + Utc::now().add(Duration::seconds( + response + .expires_in() + .map(|d| d.as_secs() - 120) + .unwrap_or(0) as i64, + )), + ), + ) + .await?; + + tracing::info!(provider = ?path.provider, api_user_id = ?api_user.id, "Generated access token"); + + Ok(HttpResponseOk(OAuthAuthzCodeExchangeResponse { + token_type: "Bearer".to_string(), + access_token: token.signed_token, + expires_in: token.expires_at.timestamp() - Utc::now().timestamp(), + })) +} diff --git a/rfd-api/src/endpoints/login/oauth/device_token.rs b/rfd-api/src/endpoints/login/oauth/device_token.rs new file mode 100644 index 0000000..b1d7baa --- /dev/null +++ b/rfd-api/src/endpoints/login/oauth/device_token.rs @@ -0,0 +1,193 @@ +use chrono::{DateTime, Utc}; +use dropshot::{endpoint, HttpError, HttpResponseOk, Method, Path, RequestContext, TypedBody}; +use http::{header, Request, Response, StatusCode}; +use hyper::{body::to_bytes, Body}; +use oauth2::{basic::BasicTokenType, EmptyExtraTokenFields, StandardTokenResponse, TokenResponse}; +use schemars::JsonSchema; +use serde::{Deserialize, Serialize}; +use std::fmt::Debug; +use tap::TapFallible; +use trace_request::trace_request; +use tracing::instrument; + +use super::{OAuthProvider, OAuthProviderInfo, OAuthProviderNameParam, UserInfoProvider}; +use crate::{ + context::ApiContext, endpoints::login::LoginError, error::ApiError, util::response::bad_request, +}; + +// Get the metadata about an OAuth provider necessary to begin a device code exchange +#[trace_request] +#[endpoint { + method = GET, + path = "/login/oauth/{provider}/device" +}] +#[instrument(skip(rqctx), fields(request_id = rqctx.request_id), err(Debug))] +pub async fn get_device_provider( + rqctx: RequestContext, + path: Path, +) -> Result, HttpError> { + let path = path.into_inner(); + + let provider = rqctx + .context() + .get_oauth_provider(&path.provider) + .await + .map_err(ApiError::OAuth)?; + + Ok(HttpResponseOk( + provider.provider_info(&rqctx.context().public_url), + )) +} + +#[derive(Debug, Deserialize, JsonSchema, Serialize)] +pub struct AccessTokenExchangeRequest { + pub device_code: String, + pub grant_type: String, + pub expires_at: Option>, +} + +#[derive(Serialize)] +pub struct AccessTokenExchange { + provider: ProviderTokenExchange, + expires_at: Option>, +} + +#[derive(Serialize)] +pub struct ProviderTokenExchange { + client_id: String, + device_code: String, + grant_type: String, + client_secret: String, +} + +impl AccessTokenExchange { + pub fn new( + req: AccessTokenExchangeRequest, + provider: &Box, + ) -> Option { + provider.client_secret().map(|client_secret| Self { + provider: ProviderTokenExchange { + client_id: provider.client_id().to_string(), + device_code: req.device_code, + grant_type: req.grant_type, + client_secret: client_secret.to_string(), + }, + expires_at: req.expires_at, + }) + } +} + +#[derive(Debug, Deserialize, JsonSchema, Serialize)] +pub struct ProxyTokenResponse { + access_token: String, + token_type: String, + expires_in: Option, + refresh_token: Option, + scopes: Option>, +} + +// Complete a device exchange request against the specified provider. This effectively proxies the +// requests that would go to the provider, captures the returned access tokens, and registers a +// new internal user as needed. The user is then returned an token that is valid for interacting +// with the RFD API +#[endpoint { + method = POST, + path = "/login/oauth/{provider}/device/exchange", + content_type = "application/x-www-form-urlencoded", +}] +#[instrument(skip(rqctx, body), fields(request_id = rqctx.request_id), err(Debug))] +pub async fn exchange_device_token( + rqctx: RequestContext, + path: Path, + body: TypedBody, +) -> Result, HttpError> { + let ctx = rqctx.context(); + let path = path.into_inner(); + let mut provider = ctx + .get_oauth_provider(&path.provider) + .await + .map_err(ApiError::OAuth)?; + + tracing::debug!(provider = ?provider.name(), "Acquired OAuth provider for token exchange"); + + let exchange_request = body.into_inner(); + + if let Some(mut exchange) = AccessTokenExchange::new(exchange_request, &mut provider) { + exchange.provider.client_secret = exchange.provider.client_secret; + + let token_exchange_endpoint = provider.token_exchange_endpoint(); + + // We know that this is safe to unwrap as we just deserialized it via the body Extractor + let body: Body = serde_urlencoded::to_string(&exchange.provider) + .unwrap() + .into(); + + let request = Request::builder() + .method(Method::POST) + .header(header::CONTENT_TYPE, provider.token_exchange_content_type()) + .uri(token_exchange_endpoint) + .body(body) + .tap_err(|err| tracing::error!(?err, "Failed to construct token exchange request"))?; + + let response = ctx + .https_client + .request(request) + .await + .tap_err(|err| tracing::error!(?err, "Token exchange request failed"))?; + + if response.status().is_success() { + tracing::debug!("Successfully exchanged token with provider"); + + let (_, body) = response.into_parts(); + let bytes = to_bytes(body).await?; + let parsed: StandardTokenResponse = + serde_json::from_slice(&bytes).map_err(LoginError::FailedToParseToken)?; + + let info = provider + .get_user_info(&ctx.https_client, parsed.access_token().secret()) + .await + .map_err(LoginError::UserInfo) + .tap_err(|err| tracing::error!(?err, "Failed to look up user information"))?; + + tracing::debug!("Verified and validated OAuth user"); + + let api_user = ctx.register_api_user(info).await?; + + tracing::info!(api_user_id = ?api_user.id, "Retrieved api user to generate device token for"); + + let token = ctx + .register_access_token(&api_user, &api_user.permissions, exchange.expires_at) + .await?; + + tracing::info!(provider = ?path.provider, api_user_id = ?api_user.id, "Generated access token"); + + Ok(Response::builder() + .status(StatusCode::OK) + .header(header::CONTENT_TYPE, "application/json") + .body( + serde_json::to_string(&ProxyTokenResponse { + access_token: token.signed_token, + token_type: "Bearer".to_string(), + expires_in: Some( + (token.expires_at - Utc::now()) + .num_seconds() + .try_into() + .unwrap_or(0), + ), + refresh_token: None, + scopes: None, + }) + .unwrap() + .into(), + )?) + } else { + tracing::warn!(provider = ?path.provider, "Received error response from OAuth provider"); + + Ok(response) + } + } else { + tracing::info!(provider = ?path.provider, "Found an OAuth provider, but it is not configured properly"); + + Err(bad_request("Invalid provider")) + } +} diff --git a/rfd-api/src/endpoints/login/oauth/google.rs b/rfd-api/src/endpoints/login/oauth/google.rs index 1fa2bff..5a4ac85 100644 --- a/rfd-api/src/endpoints/login/oauth/google.rs +++ b/rfd-api/src/endpoints/login/oauth/google.rs @@ -71,8 +71,10 @@ impl OAuthProvider for GoogleOAuthProvider { &self.public.client_id } - fn client_secret(&mut self) -> Option { - self.private.take().map(|private| private.client_secret) + fn client_secret(&self) -> Option<&str> { + self.private + .as_ref() + .map(|private| private.client_secret.as_str()) } fn user_info_endpoint(&self) -> &str { diff --git a/rfd-api/src/endpoints/login/oauth/mod.rs b/rfd-api/src/endpoints/login/oauth/mod.rs index 1b978f8..79981c6 100644 --- a/rfd-api/src/endpoints/login/oauth/mod.rs +++ b/rfd-api/src/endpoints/login/oauth/mod.rs @@ -1,22 +1,18 @@ use async_trait::async_trait; -use chrono::{DateTime, Utc}; -use dropshot::{endpoint, HttpError, HttpResponseOk, Method, Path, RequestContext, TypedBody}; -use http::{header, Request, Response, StatusCode}; +use dropshot::Method; +use http::{header, Request}; use hyper::{body::to_bytes, client::connect::Connect, Body, Client}; -use oauth2::{basic::BasicTokenType, EmptyExtraTokenFields, StandardTokenResponse, TokenResponse}; +use oauth2::{basic::BasicClient, url::ParseError, AuthUrl, ClientId, ClientSecret, TokenUrl}; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; use std::fmt::{Debug, Display}; -use tap::TapFallible; use thiserror::Error; -use trace_request::trace_request; use tracing::instrument; use super::{UserInfo, UserInfoError, UserInfoProvider}; -use crate::{ - context::ApiContext, endpoints::login::LoginError, error::ApiError, util::response::bad_request, -}; +pub mod authz_code; +pub mod device_token; pub mod google; #[derive(Debug, Error)] @@ -29,7 +25,7 @@ pub trait OAuthProvider: ExtractUserInfo + Debug { fn name(&self) -> OAuthProviderName; fn scopes(&self) -> Vec<&str>; fn client_id(&self) -> &str; - fn client_secret(&mut self) -> Option; + fn client_secret(&self) -> Option<&str>; fn user_info_endpoint(&self) -> &str; fn device_code_endpoint(&self) -> &str; fn auth_url_endpoint(&self) -> &str; @@ -50,6 +46,16 @@ pub trait OAuthProvider: ExtractUserInfo + Debug { .collect::>(), } } + + fn as_client(&self) -> Result { + Ok(BasicClient::new( + ClientId::new(self.client_id().to_string()), + self.client_secret() + .map(|s| ClientSecret::new(s.to_string())), + AuthUrl::new(self.auth_url_endpoint().to_string())?, + Some(TokenUrl::new(self.token_exchange_endpoint().to_string())?), + )) + } } pub trait ExtractUserInfo { @@ -124,180 +130,3 @@ impl Display for OAuthProviderName { pub struct OAuthProviderNameParam { provider: OAuthProviderName, } - -// Get the metadata about an OAuth provider necessary to begin a device code exchange -#[trace_request] -#[endpoint { - method = GET, - path = "/login/oauth/{provider}/device" -}] -#[instrument(skip(rqctx), fields(request_id = rqctx.request_id), err(Debug))] -pub async fn get_device_provider( - rqctx: RequestContext, - path: Path, -) -> Result, HttpError> { - let path = path.into_inner(); - - let provider = rqctx - .context() - .get_oauth_provider(&path.provider) - .await - .map_err(ApiError::OAuth)?; - - Ok(HttpResponseOk( - provider.provider_info(&rqctx.context().public_url), - )) -} - -#[derive(Debug, Deserialize, JsonSchema, Serialize)] -pub struct AccessTokenExchangeRequest { - pub device_code: String, - pub grant_type: String, - pub expires_at: Option>, -} - -#[derive(Serialize)] -pub struct AccessTokenExchange { - provider: ProviderTokenExchange, - expires_at: Option>, -} - -#[derive(Serialize)] -pub struct ProviderTokenExchange { - client_id: String, - device_code: String, - grant_type: String, - client_secret: String, -} - -impl AccessTokenExchange { - pub fn new( - req: AccessTokenExchangeRequest, - provider: &mut Box, - ) -> Option { - provider.client_secret().map(|client_secret| Self { - provider: ProviderTokenExchange { - client_id: provider.client_id().to_string(), - device_code: req.device_code, - grant_type: req.grant_type, - client_secret, - }, - expires_at: req.expires_at, - }) - } -} - -#[derive(Debug, Deserialize, JsonSchema, Serialize)] -pub struct ProxyTokenResponse { - access_token: String, - token_type: String, - expires_in: Option, - refresh_token: Option, - scopes: Option>, -} - -// Complete a device exchange request against the specified provider. This effectively proxies the -// requests that would go to the provider, captures the returned access tokens, and registers a -// new internal user as needed. The user is then returned an token that is valid for interacting -// with the RFD API -#[endpoint { - method = POST, - path = "/login/oauth/{provider}/device/exchange", - content_type = "application/x-www-form-urlencoded", -}] -#[instrument(skip(rqctx, body), fields(request_id = rqctx.request_id), err(Debug))] -pub async fn exchange_device_token( - rqctx: RequestContext, - path: Path, - body: TypedBody, -) -> Result, HttpError> { - let ctx = rqctx.context(); - let path = path.into_inner(); - let mut provider = ctx - .get_oauth_provider(&path.provider) - .await - .map_err(ApiError::OAuth)?; - - tracing::debug!(provider = ?provider.name(), "Acquired OAuth provider for token exchange"); - - let exchange_request = body.into_inner(); - - if let Some(mut exchange) = AccessTokenExchange::new(exchange_request, &mut provider) { - exchange.provider.client_secret = exchange.provider.client_secret; - - let token_exchange_endpoint = provider.token_exchange_endpoint(); - - // We know that this is safe to unwrap as we just deserialized it via the body Extractor - let body: Body = serde_urlencoded::to_string(&exchange.provider) - .unwrap() - .into(); - - let request = Request::builder() - .method(Method::POST) - .header(header::CONTENT_TYPE, provider.token_exchange_content_type()) - .uri(token_exchange_endpoint) - .body(body) - .tap_err(|err| tracing::error!(?err, "Failed to construct token exchange request"))?; - - let response = ctx - .https_client - .request(request) - .await - .tap_err(|err| tracing::error!(?err, "Token exchange request failed"))?; - - if response.status().is_success() { - tracing::debug!("Successfully exchanged token with provider"); - - let (_, body) = response.into_parts(); - let bytes = to_bytes(body).await?; - let parsed: StandardTokenResponse = - serde_json::from_slice(&bytes).map_err(LoginError::FailedToParseToken)?; - - let info = provider - .get_user_info(&ctx.https_client, parsed.access_token().secret()) - .await - .map_err(LoginError::UserInfo) - .tap_err(|err| tracing::error!(?err, "Failed to look up user information"))?; - - tracing::debug!("Verified and validated OAuth user"); - - let api_user = ctx.register_api_user(info).await?; - - tracing::info!(api_user_id = ?api_user.id, "Retrieved api user to generate device token for"); - - let token = ctx - .register_access_token(&api_user, &api_user.permissions, exchange.expires_at) - .await?; - - tracing::info!(provider = ?path.provider, api_user_id = ?api_user.id, "Generated access token"); - - Ok(Response::builder() - .status(StatusCode::OK) - .header(header::CONTENT_TYPE, "application/json") - .body( - serde_json::to_string(&ProxyTokenResponse { - access_token: token.signed_token, - token_type: "Bearer".to_string(), - expires_in: Some( - (token.expires_at - Utc::now()) - .num_seconds() - .try_into() - .unwrap_or(0), - ), - refresh_token: None, - scopes: None, - }) - .unwrap() - .into(), - )?) - } else { - tracing::warn!(provider = ?path.provider, "Received error response from OAuth provider"); - - Ok(response) - } - } else { - tracing::info!(provider = ?path.provider, "Found an OAuth provider, but it is not configured properly"); - - Err(bad_request("Invalid provider")) - } -} diff --git a/rfd-api/src/error.rs b/rfd-api/src/error.rs index 569711c..adab750 100644 --- a/rfd-api/src/error.rs +++ b/rfd-api/src/error.rs @@ -3,13 +3,15 @@ use rfd_model::storage::StoreError; use thiserror::Error; use crate::{ - authn::jwt::SignerError, + authn::{jwt::SignerError, key::EncryptorError}, endpoints::login::{oauth::OAuthProviderError, LoginError}, util::response::internal_error, }; #[derive(Debug, Error)] pub enum AppError { + #[error(transparent)] + EncryptorError(#[from] EncryptorError), #[error("At least one JWT signing key must be configured")] NoConfiguredJwtKeys, #[error(transparent)] diff --git a/rfd-api/src/main.rs b/rfd-api/src/main.rs index b689b16..8371dcd 100644 --- a/rfd-api/src/main.rs +++ b/rfd-api/src/main.rs @@ -3,7 +3,7 @@ use permissions::ApiPermission; use rfd_model::{ permissions::{Caller, Permissions}, storage::postgres::PostgresStore, - ApiUser, ApiUserToken, + ApiKey, ApiUser, }; use server::{server, ServerConfig}; use std::{ @@ -16,10 +16,7 @@ use tracing_subscriber::EnvFilter; use crate::{ config::{AppConfig, ServerLogFormat}, email_validator::DomainValidator, - endpoints::login::{ - jwt::{google::GoogleOidcJwks, JwtProviderName}, - oauth::{google::GoogleOAuthProvider, OAuthProviderName}, - }, + endpoints::login::oauth::{google::GoogleOAuthProvider, OAuthProviderName}, }; mod authn; @@ -35,7 +32,7 @@ mod util; pub type ApiCaller = Caller; pub type ApiPermissions = Permissions; pub type User = ApiUser; -pub type UserToken = ApiUserToken; +pub type UserToken = ApiKey; #[tokio::main] async fn main() -> Result<(), Box> { @@ -78,12 +75,12 @@ async fn main() -> Result<(), Box> { ) } - if let Some(google) = config.authn.jwt.google { - context.insert_jwks_provider( - JwtProviderName::Google, - Box::new(GoogleOidcJwks::new(google.issuer, google.well_known_uri)), - ) - } + // if let Some(google) = config.authn.jwt.google { + // context.insert_jwks_provider( + // JwtProviderName::Google, + // Box::new(GoogleOidcJwks::new(google.issuer, google.well_known_uri)), + // ) + // } tracing::debug!(?config.spec, "Spec configuration"); diff --git a/rfd-api/src/seed.rs b/rfd-api/src/seed.rs index 17321df..9aaefdd 100644 --- a/rfd-api/src/seed.rs +++ b/rfd-api/src/seed.rs @@ -1,5 +1,5 @@ use chrono::{Duration, Utc}; -use rfd_model::{storage::StoreError, ApiUser, NewApiUser, NewApiUserToken}; +use rfd_model::{storage::StoreError, ApiUser, NewApiUser, NewApiKey}; use serde::Serialize; use thiserror::Error; use uuid::Uuid; @@ -7,14 +7,14 @@ use uuid::Uuid; use crate::{ authn::key::NewApiKey, context::ApiContext, - endpoints::api_user::InitialApiUserTokenResponse, + endpoints::api_user::InitialApiKKeyResponse, permissions::{ApiPermission, ApiUserPermission, RfdPermission}, }; #[derive(Debug, Serialize)] pub struct SeedApiUser { pub user: ApiUser, - pub token: InitialApiUserTokenResponse, + pub token: InitialApiKKeyResponse, } #[derive(Debug, Error)] @@ -53,7 +53,7 @@ pub async fn seed(ctx: &ApiContext) -> Result { let stored_token = ctx .create_api_user_token( - NewApiUserToken { + NewApiKey { id: token_id, api_user_id: user.id, token: hash, @@ -66,7 +66,7 @@ pub async fn seed(ctx: &ApiContext) -> Result { Ok(SeedApiUser { user, - token: InitialApiUserTokenResponse { + token: InitialApiKKeyResponse { id: stored_token.id, token, permissions: stored_token.permissions, diff --git a/rfd-api/src/server.rs b/rfd-api/src/server.rs index d79b779..a189364 100644 --- a/rfd-api/src/server.rs +++ b/rfd-api/src/server.rs @@ -13,10 +13,9 @@ use crate::{ create_api_user, create_api_user_token, delete_api_user_token, get_api_user, get_api_user_token, get_self, list_api_user_tokens, update_api_user, }, - login::{ - access_token::access_token_login, - jwt::jwt_login, - oauth::{exchange_device_token, get_device_provider}, + login::oauth::{ + authz_code::{authz_code_exchange, authz_code_redirect, authz_code_return}, + device_token::{exchange_device_token, get_device_provider}, }, rfd::get_rfd, webhook::github_webhook, @@ -82,11 +81,10 @@ pub fn server( api.register(create_api_user_token).unwrap(); api.register(delete_api_user_token).unwrap(); - // Access Token Login - api.register(access_token_login).unwrap(); - - // JWT Login - api.register(jwt_login).unwrap(); + // OAuth Authorization Login + api.register(authz_code_redirect).unwrap(); + api.register(authz_code_return).unwrap(); + api.register(authz_code_exchange).unwrap(); // OAuth Device Login api.register(get_device_provider).unwrap(); diff --git a/rfd-api/src/util.rs b/rfd-api/src/util.rs index 183f5e5..2fe26fa 100644 --- a/rfd-api/src/util.rs +++ b/rfd-api/src/util.rs @@ -1,6 +1,7 @@ pub mod response { use dropshot::HttpError; use http::StatusCode; + use std::error::Error; pub fn unauthorized() -> HttpError { client_error(StatusCode::UNAUTHORIZED, "Unauthorized") @@ -24,6 +25,14 @@ pub mod response { HttpError::for_not_found(None, internal_message.to_string()) } + pub fn to_internal_error(error: E) -> HttpError + where + E: Error, + { + tracing::info!(?error, "Request failed"); + internal_error(String::new()) + } + pub fn internal_error(internal_message: S) -> HttpError where S: ToString, diff --git a/rfd-model/Cargo.toml b/rfd-model/Cargo.toml index 5547dd2..32a0986 100644 --- a/rfd-model/Cargo.toml +++ b/rfd-model/Cargo.toml @@ -14,6 +14,7 @@ async-trait = { workspace = true } bb8 = { workspace = true } chrono = { workspace = true, features = ["serde"] } diesel = { workspace = true, features = ["chrono", "uuid", "serde_json"] } +http = { workspace = true } mockall = { workspace = true, optional = true } partial-struct = { workspace = true } schemars = { workspace = true, features = ["chrono", "uuid1"] } diff --git a/rfd-model/migrations/2023-01-03-032421_api_user/down.sql b/rfd-model/migrations/2023-01-03-032421_api_user/down.sql index 5994466..89c425f 100644 --- a/rfd-model/migrations/2023-01-03-032421_api_user/down.sql +++ b/rfd-model/migrations/2023-01-03-032421_api_user/down.sql @@ -1,4 +1,4 @@ -DROP TABLE api_user_token; +DROP TABLE api_key; DROP TABLE api_user_provider; diff --git a/rfd-model/migrations/2023-01-03-032421_api_user/up.sql b/rfd-model/migrations/2023-01-03-032421_api_user/up.sql index 8c337fc..d5dea2d 100644 --- a/rfd-model/migrations/2023-01-03-032421_api_user/up.sql +++ b/rfd-model/migrations/2023-01-03-032421_api_user/up.sql @@ -6,10 +6,10 @@ CREATE TABLE api_user ( deleted_at TIMESTAMPTZ ); -CREATE TABLE api_user_token ( +CREATE TABLE api_key ( id UUID PRIMARY KEY, api_user_id UUID REFERENCES api_user (id) NOT NULL, - token TEXT NOT NULL UNIQUE, + key TEXT NOT NULL UNIQUE, permissions JSONB NOT NULL, expires_at TIMESTAMPTZ NOT NULL, created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), @@ -28,5 +28,4 @@ CREATE TABLE api_user_provider ( deleted_at TIMESTAMPTZ ); -CREATE UNIQUE INDEX api_user_token_idx ON api_user_token (api_user_id, token); CREATE UNIQUE INDEX api_user_provider_idx ON api_user_provider (provider, provider_id); diff --git a/rfd-model/migrations/2023-09-06-193013_login_attempt/down.sql b/rfd-model/migrations/2023-09-06-193013_login_attempt/down.sql new file mode 100644 index 0000000..1669499 --- /dev/null +++ b/rfd-model/migrations/2023-09-06-193013_login_attempt/down.sql @@ -0,0 +1,2 @@ +DROP TABLE login_attempt; +DROP TYPE ATTEMPT_STATE; diff --git a/rfd-model/migrations/2023-09-06-193013_login_attempt/up.sql b/rfd-model/migrations/2023-09-06-193013_login_attempt/up.sql new file mode 100644 index 0000000..c6baad7 --- /dev/null +++ b/rfd-model/migrations/2023-09-06-193013_login_attempt/up.sql @@ -0,0 +1,22 @@ +CREATE TYPE ATTEMPT_STATE as ENUM('new', 'remote_authenticated', 'failed', 'complete'); + +CREATE TABLE login_attempt( + id UUID PRIMARY KEY, + attempt_state ATTEMPT_STATE NOT NULL, + + client_id UUID NOT NULL, + redirect_uri VARCHAR NOT NULL, + state VARCHAR, + pkce_challenge VARCHAR, + pkce_challenge_method VARCHAR, + authz_code VARCHAR, + expires_at TIMESTAMPTZ, + + provider VARCHAR NOT NULL, + provider_state VARCHAR NOT NULL UNIQUE, + provider_pkce_verifier VARCHAR NOT NULL, + provider_authz_code VARCHAR, + + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW() +) diff --git a/rfd-model/src/db.rs b/rfd-model/src/db.rs index b004a20..6381eae 100644 --- a/rfd-model/src/db.rs +++ b/rfd-model/src/db.rs @@ -6,10 +6,10 @@ use uuid::Uuid; use crate::{ permissions::Permissions, schema::{ - api_user, api_user_access_token, api_user_provider, api_user_token, job, rfd, rfd_pdf, - rfd_revision, + api_key, api_user, api_user_access_token, api_user_provider, job, login_attempt, rfd, + rfd_pdf, rfd_revision, }, - schema_ext::{ContentFormat, PdfSource}, + schema_ext::{ContentFormat, LoginAttemptState, PdfSource}, }; #[derive(Debug, Deserialize, Serialize, Queryable, Insertable)] @@ -82,11 +82,11 @@ pub struct ApiUserModel { } #[derive(Debug, Deserialize, Serialize, Queryable, Insertable)] -#[diesel(table_name = api_user_token)] -pub struct ApiUserTokenModel { +#[diesel(table_name = api_key)] +pub struct ApiKeyModel { pub id: Uuid, pub api_user_id: Uuid, - pub token: String, + pub key: String, pub permissions: Permissions, pub expires_at: DateTime, pub created_at: DateTime, @@ -116,3 +116,23 @@ pub struct ApiUserAccessTokenModel { pub created_at: DateTime, pub updated_at: DateTime, } + +#[derive(Debug, Deserialize, Serialize, Queryable, Insertable)] +#[diesel(table_name = login_attempt)] +pub struct LoginAttemptModel { + pub id: Uuid, + pub attempt_state: LoginAttemptState, + pub client_id: Uuid, + pub redirect_uri: String, + pub state: Option, + pub pkce_challenge: Option, + pub pkce_challenge_method: Option, + pub authz_code: Option, + pub expires_at: Option>, + pub provider: String, + pub provider_state: String, + pub provider_pkce_verifier: String, + pub provider_authz_code: Option, + pub created_at: DateTime, + pub updated_at: DateTime, +} diff --git a/rfd-model/src/lib.rs b/rfd-model/src/lib.rs index c267e07..c09a29d 100644 --- a/rfd-model/src/lib.rs +++ b/rfd-model/src/lib.rs @@ -1,10 +1,13 @@ +use std::{collections::BTreeMap, fmt::Display}; + use chrono::{DateTime, Utc}; -use db::{JobModel, RfdModel, RfdPdfModel, RfdRevisionModel}; +use db::{JobModel, LoginAttemptModel, RfdModel, RfdPdfModel, RfdRevisionModel}; use partial_struct::partial; use permissions::Permissions; -use schema_ext::{ContentFormat, PdfSource}; +use schema_ext::{ContentFormat, LoginAttemptState, PdfSource}; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; +use thiserror::Error; use uuid::Uuid; pub mod db; @@ -181,19 +184,19 @@ pub struct ApiUserProvider { pub deleted_at: Option>, } -#[partial(NewApiUserToken)] +#[partial(NewApiKey)] #[derive(Debug, Clone, PartialEq, Serialize, Deserialize, JsonSchema)] -pub struct ApiUserToken { +pub struct ApiKey { pub id: Uuid, pub api_user_id: Uuid, - pub token: String, + pub key: String, pub permissions: Permissions, pub expires_at: DateTime, - #[partial(NewApiUserToken(skip))] + #[partial(NewApiKey(skip))] pub created_at: DateTime, - #[partial(NewApiUserToken(skip))] + #[partial(NewApiKey(skip))] pub updated_at: DateTime, - #[partial(NewApiUserToken(skip))] + #[partial(NewApiKey(skip))] pub deleted_at: Option>, } @@ -208,3 +211,107 @@ pub struct AccessToken { #[partial(NewAccessToken(skip))] pub updated_at: DateTime, } + +#[partial(NewLoginAttempt)] +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, JsonSchema)] +pub struct LoginAttempt { + pub id: Uuid, + pub attempt_state: LoginAttemptState, + pub client_id: Uuid, + pub redirect_uri: String, + pub state: Option, + pub pkce_challenge: Option, + pub pkce_challenge_method: Option, + pub authz_code: Option, + pub expires_at: Option>, + pub provider: String, + pub provider_state: String, + pub provider_pkce_verifier: String, + pub provider_authz_code: Option, + #[partial(NewLoginAttempt(skip))] + pub created_at: DateTime, + #[partial(NewLoginAttempt(skip))] + pub updated_at: DateTime, +} + +impl LoginAttempt { + pub fn callback_url(&self) -> String { + let mut params = BTreeMap::new(); + + if let Some(state) = &self.state { + params.insert("state", state); + } + + if let Some(authz_code) = &self.authz_code { + params.insert("code", authz_code); + } + + let query_string = params + .into_iter() + .map(|(k, v)| format!("{}={}", k, v)) + .collect::>() + .join("&"); + + [self.redirect_uri.as_str(), query_string.as_str()].join("?") + } +} + +#[derive(Debug, Error)] +pub struct InvalidValueError { + pub field: String, + pub error: String, +} + +impl Display for InvalidValueError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{} has an invalid value: {}", self.field, self.error) + } +} + +impl NewLoginAttempt { + pub fn new( + client_id: Uuid, + redirect_uri: String, + provider: String, + provider_state: String, + provider_pkce_verifier: String, + ) -> Result { + Ok(Self { + id: Uuid::new_v4(), + attempt_state: LoginAttemptState::New, + client_id, + redirect_uri, + state: None, + pkce_challenge: None, + pkce_challenge_method: None, + authz_code: None, + expires_at: None, + provider, + provider_state, + provider_pkce_verifier, + provider_authz_code: None, + }) + } +} + +impl From for LoginAttempt { + fn from(value: LoginAttemptModel) -> Self { + Self { + id: value.id, + attempt_state: value.attempt_state, + client_id: value.client_id, + redirect_uri: value.redirect_uri, + state: value.state, + pkce_challenge: value.pkce_challenge, + pkce_challenge_method: value.pkce_challenge_method, + authz_code: value.authz_code, + expires_at: value.expires_at, + provider: value.provider, + provider_state: value.provider_state, + provider_pkce_verifier: value.provider_pkce_verifier, + provider_authz_code: value.provider_authz_code, + created_at: value.created_at, + updated_at: value.updated_at, + } + } +} diff --git a/rfd-model/src/schema.rs b/rfd-model/src/schema.rs index bb9a44d..ad5bc1c 100644 --- a/rfd-model/src/schema.rs +++ b/rfd-model/src/schema.rs @@ -2,12 +2,8 @@ pub mod sql_types { #[derive(diesel::sql_types::SqlType)] - #[diesel(postgres_type(name = "dispatch_mode"))] - pub struct DispatchMode; - - #[derive(diesel::sql_types::SqlType)] - #[diesel(postgres_type(name = "dispatch_status"))] - pub struct DispatchStatus; + #[diesel(postgres_type(name = "attempt_state"))] + pub struct AttemptState; #[derive(diesel::sql_types::SqlType)] #[diesel(postgres_type(name = "rfd_content_format"))] @@ -19,14 +15,15 @@ pub mod sql_types { } diesel::table! { - allow_list (id) { - id -> Int4, - username -> Varchar, - #[sql_name = "type"] - type_ -> Varchar, - rules -> Array>, + api_key (id) { + id -> Uuid, + api_user_id -> Uuid, + key -> Text, + permissions -> Jsonb, + expires_at -> Timestamptz, created_at -> Timestamptz, updated_at -> Timestamptz, + deleted_at -> Nullable, } } @@ -63,44 +60,6 @@ diesel::table! { } } -diesel::table! { - api_user_token (id) { - id -> Uuid, - api_user_id -> Uuid, - token -> Text, - permissions -> Jsonb, - expires_at -> Timestamptz, - created_at -> Timestamptz, - updated_at -> Timestamptz, - deleted_at -> Nullable, - } -} - -diesel::table! { - use diesel::sql_types::*; - use super::sql_types::DispatchMode; - use super::sql_types::DispatchStatus; - - dispatch (id) { - id -> Int4, - dispatch_id -> Uuid, - mode -> DispatchMode, - pattern -> Varchar, - workflow -> Int8, - owner -> Varchar, - repository -> Varchar, - #[sql_name = "ref"] - ref_ -> Varchar, - response_status -> Int4, - duration -> Int8, - created_at -> Timestamptz, - source -> Nullable, - requires_token -> Bool, - status -> DispatchStatus, - scheduled_for -> Timestamptz, - } -} - diesel::table! { job (id) { id -> Int4, @@ -116,6 +75,29 @@ diesel::table! { } } +diesel::table! { + use diesel::sql_types::*; + use super::sql_types::AttemptState; + + login_attempt (id) { + id -> Uuid, + attempt_state -> AttemptState, + client_id -> Uuid, + redirect_uri -> Varchar, + state -> Nullable, + pkce_challenge -> Nullable, + pkce_challenge_method -> Nullable, + authz_code -> Nullable, + expires_at -> Nullable, + provider -> Varchar, + provider_state -> Varchar, + provider_pkce_verifier -> Varchar, + provider_authz_code -> Nullable, + created_at -> Timestamptz, + updated_at -> Timestamptz, + } +} + diesel::table! { rfd (id) { id -> Uuid, @@ -164,39 +146,20 @@ diesel::table! { } } -diesel::table! { - rule (id) { - id -> Int4, - pattern -> Varchar, - target_repository -> Varchar, - target_owner -> Varchar, - target_ref -> Varchar, - target_workflow -> Int8, - enabled -> Bool, - created_at -> Timestamptz, - updated_at -> Timestamptz, - requires_token -> Bool, - debounce -> Int4, - conditions -> Nullable, - } -} - +diesel::joinable!(api_key -> api_user (api_user_id)); diesel::joinable!(api_user_access_token -> api_user (api_user_id)); diesel::joinable!(api_user_provider -> api_user (api_user_id)); -diesel::joinable!(api_user_token -> api_user (api_user_id)); diesel::joinable!(rfd_pdf -> rfd_revision (rfd_revision_id)); diesel::joinable!(rfd_revision -> rfd (rfd_id)); diesel::allow_tables_to_appear_in_same_query!( - allow_list, + api_key, api_user, api_user_access_token, api_user_provider, - api_user_token, - dispatch, job, + login_attempt, rfd, rfd_pdf, rfd_revision, - rule, ); diff --git a/rfd-model/src/schema_ext.rs b/rfd-model/src/schema_ext.rs index 7e8f5ed..c91358b 100644 --- a/rfd-model/src/schema_ext.rs +++ b/rfd-model/src/schema_ext.rs @@ -16,7 +16,7 @@ use std::{ use crate::{ permissions::Permissions, - schema::sql_types::{RfdContentFormat, RfdPdfSource}, + schema::sql_types::{AttemptState, RfdContentFormat, RfdPdfSource}, }; macro_rules! sql_conversion { @@ -115,3 +115,38 @@ where Ok(serde_json::from_value(value)?) } } + +#[derive(Debug, PartialEq, Clone, FromSqlRow, AsExpression, Serialize, Deserialize, JsonSchema)] +#[diesel(sql_type = AttemptState)] +#[serde(rename_all = "lowercase")] +pub enum LoginAttemptState { + Complete, + Failed, + New, + RemoteAuthenticated, +} + +sql_conversion! { + AttemptState => LoginAttemptState, + Complete => b"complete", + Failed => b"failed", + New => b"new", + RemoteAuthenticated => b"remote_authenticated", +} + +impl Display for LoginAttemptState { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + LoginAttemptState::Complete => write!(f, "complete"), + LoginAttemptState::Failed => write!(f, "failed"), + LoginAttemptState::New => write!(f, "new"), + LoginAttemptState::RemoteAuthenticated => write!(f, "remote_authenticated"), + } + } +} + +impl Default for LoginAttemptState { + fn default() -> Self { + Self::New + } +} diff --git a/rfd-model/src/storage/mod.rs b/rfd-model/src/storage/mod.rs index aefa082..36f2b83 100644 --- a/rfd-model/src/storage/mod.rs +++ b/rfd-model/src/storage/mod.rs @@ -9,9 +9,11 @@ use thiserror::Error; use uuid::Uuid; use crate::{ - permissions::Permission, schema_ext::PdfSource, AccessToken, ApiUser, ApiUserProvider, - ApiUserToken, Job, NewAccessToken, NewApiUser, NewApiUserProvider, NewApiUserToken, NewJob, - NewRfd, NewRfdPdf, NewRfdRevision, Rfd, RfdPdf, RfdRevision, + permissions::Permission, + schema_ext::{LoginAttemptState, PdfSource}, + AccessToken, ApiKey, ApiUser, ApiUserProvider, Job, LoginAttempt, NewAccessToken, NewApiKey, + NewApiUser, NewApiUserProvider, NewJob, NewLoginAttempt, NewRfd, NewRfdPdf, NewRfdRevision, + Rfd, RfdPdf, RfdRevision, }; pub mod postgres; @@ -242,27 +244,28 @@ pub trait ApiUserStore { } #[derive(Debug, Default)] -pub struct ApiUserTokenFilter { +pub struct ApiKeyFilter { pub api_user_id: Option>, + pub key: Option>, pub expired: bool, pub deleted: bool, } #[cfg_attr(feature = "mock", automock)] #[async_trait] -pub trait ApiUserTokenStore { - async fn get(&self, id: &Uuid, deleted: bool) -> Result>, StoreError>; +pub trait ApiKeyStore { + async fn get(&self, id: &Uuid, deleted: bool) -> Result>, StoreError>; async fn list( &self, - filter: ApiUserTokenFilter, + filter: ApiKeyFilter, pagination: &ListPagination, - ) -> Result>, StoreError>; + ) -> Result>, StoreError>; async fn upsert( &self, - token: NewApiUserToken, + token: NewApiKey, api_user: &ApiUser, - ) -> Result, StoreError>; - async fn delete(&self, id: &Uuid) -> Result>, StoreError>; + ) -> Result, StoreError>; + async fn delete(&self, id: &Uuid) -> Result>, StoreError>; } #[derive(Debug, Default)] @@ -306,3 +309,23 @@ pub trait AccessTokenStore { ) -> Result, StoreError>; async fn upsert(&self, token: NewAccessToken) -> Result; } + +#[derive(Debug, Default)] +pub struct LoginAttemptFilter { + pub id: Option>, + pub client_id: Option>, + pub attempt_state: Option>, + pub authz_code: Option>, +} + +#[cfg_attr(feature = "mock", automock)] +#[async_trait] +pub trait LoginAttemptStore { + async fn get(&self, id: &Uuid) -> Result, StoreError>; + async fn list( + &self, + filter: LoginAttemptFilter, + pagination: &ListPagination, + ) -> Result, StoreError>; + async fn upsert(&self, attempt: NewLoginAttempt) -> Result; +} diff --git a/rfd-model/src/storage/postgres.rs b/rfd-model/src/storage/postgres.rs index 5138c56..4eeb618 100644 --- a/rfd-model/src/storage/postgres.rs +++ b/rfd-model/src/storage/postgres.rs @@ -17,25 +17,25 @@ use uuid::Uuid; use crate::{ db::{ - ApiUserAccessTokenModel, ApiUserModel, ApiUserProviderModel, ApiUserTokenModel, JobModel, - RfdModel, RfdPdfModel, RfdRevisionModel, + ApiKeyModel, ApiUserAccessTokenModel, ApiUserModel, ApiUserProviderModel, JobModel, + LoginAttemptModel, RfdModel, RfdPdfModel, RfdRevisionModel, }, permissions::{Permission, Permissions}, schema::{ - api_user, api_user_access_token, api_user_provider, api_user_token, job, rfd, rfd_pdf, - rfd_revision, + api_key, api_user, api_user_access_token, api_user_provider, job, login_attempt, rfd, + rfd_pdf, rfd_revision, }, storage::StoreError, - AccessToken, ApiUser, ApiUserProvider, ApiUserToken, Job, NewAccessToken, NewApiUser, - NewApiUserProvider, NewApiUserToken, NewJob, NewRfd, NewRfdPdf, NewRfdRevision, Rfd, RfdPdf, - RfdRevision, + AccessToken, ApiKey, ApiUser, ApiUserProvider, Job, LoginAttempt, NewAccessToken, NewApiKey, + NewApiUser, NewApiUserProvider, NewJob, NewLoginAttempt, NewRfd, NewRfdPdf, NewRfdRevision, + Rfd, RfdPdf, RfdRevision, }; use super::{ - AccessTokenFilter, AccessTokenStore, ApiUserFilter, ApiUserProviderFilter, - ApiUserProviderStore, ApiUserStore, ApiUserTokenFilter, ApiUserTokenStore, JobFilter, JobStore, - ListPagination, RfdFilter, RfdPdfFilter, RfdPdfStore, RfdRevisionFilter, RfdRevisionStore, - RfdStore, + AccessTokenFilter, AccessTokenStore, ApiKeyFilter, ApiKeyStore, ApiUserFilter, + ApiUserProviderFilter, ApiUserProviderStore, ApiUserStore, JobFilter, JobStore, ListPagination, + LoginAttemptFilter, LoginAttemptStore, RfdFilter, RfdPdfFilter, RfdPdfStore, RfdRevisionFilter, + RfdRevisionStore, RfdStore, }; pub type DbPool = Pool>; @@ -512,74 +512,79 @@ where } #[async_trait] -impl ApiUserTokenStore for PostgresStore +impl ApiKeyStore for PostgresStore where T: Permission, { - async fn get(&self, id: &Uuid, deleted: bool) -> Result>, StoreError> { - let mut query = api_user_token::dsl::api_user_token + async fn get(&self, id: &Uuid, deleted: bool) -> Result>, StoreError> { + let mut query = api_key::dsl::api_key .into_boxed() - .filter(api_user_token::id.eq(*id)); + .filter(api_key::id.eq(*id)); if !deleted { - query = query.filter(api_user_token::deleted_at.is_null()); + query = query.filter(api_key::deleted_at.is_null()); } let result = query - .get_result_async::>(&self.conn) + .get_result_async::>(&self.conn) .await .optional()?; - Ok(result.map(|token| ApiUserToken { - id: token.id, - api_user_id: token.api_user_id, - token: token.token, - permissions: token.permissions, - expires_at: token.expires_at, - created_at: token.created_at, - updated_at: token.updated_at, - deleted_at: token.deleted_at, + Ok(result.map(|key| ApiKey { + id: key.id, + api_user_id: key.api_user_id, + key: key.key, + permissions: key.permissions, + expires_at: key.expires_at, + created_at: key.created_at, + updated_at: key.updated_at, + deleted_at: key.deleted_at, })) } async fn list( &self, - filter: ApiUserTokenFilter, + filter: ApiKeyFilter, pagination: &ListPagination, - ) -> Result>, StoreError> { - let mut query = api_user_token::dsl::api_user_token.into_boxed(); + ) -> Result>, StoreError> { + let mut query = api_key::dsl::api_key.into_boxed(); - let ApiUserTokenFilter { + let ApiKeyFilter { api_user_id, + key, expired, deleted, } = filter; if let Some(api_user_id) = api_user_id { - query = query.filter(api_user_token::api_user_id.eq_any(api_user_id)); + query = query.filter(api_key::api_user_id.eq_any(api_user_id)); + } + + if let Some(key) = key { + query = query.filter(api_key::key.eq_any(key)); } if !expired { - query = query.filter(api_user_token::expires_at.gt(Utc::now())); + query = query.filter(api_key::expires_at.gt(Utc::now())); } if !deleted { - query = query.filter(api_user_token::deleted_at.is_null()); + query = query.filter(api_key::deleted_at.is_null()); } let results = query .offset(pagination.offset) .limit(pagination.limit) - .order(api_user_token::created_at.desc()) - .get_results_async::>(&self.conn) + .order(api_key::created_at.desc()) + .get_results_async::>(&self.conn) .await?; Ok(results .into_iter() - .map(|token| ApiUserToken { + .map(|token| ApiKey { id: token.id, api_user_id: token.api_user_id, - token: token.token, + key: token.key, permissions: token.permissions, expires_at: token.expires_at, created_at: token.created_at, @@ -591,11 +596,11 @@ where async fn upsert( &self, - token: NewApiUserToken, + key: NewApiKey, api_user: &ApiUser, - ) -> Result, StoreError> { + ) -> Result, StoreError> { // Validate the the token permissions are a subset of the users permissions - let permissions: Permissions = token + let permissions: Permissions = key .permissions .inner() .iter() @@ -612,37 +617,37 @@ where .collect::>() .into(); - let token_m: ApiUserTokenModel = insert_into(api_user_token::dsl::api_user_token) + let key_m: ApiKeyModel = insert_into(api_key::dsl::api_key) .values(( - api_user_token::id.eq(token.id), - api_user_token::api_user_id.eq(token.api_user_id), - api_user_token::token.eq(token.token.clone()), - api_user_token::expires_at.eq(token.expires_at), - api_user_token::permissions.eq(permissions), + api_key::id.eq(key.id), + api_key::api_user_id.eq(key.api_user_id), + api_key::key.eq(key.key.clone()), + api_key::expires_at.eq(key.expires_at), + api_key::permissions.eq(permissions), )) .get_result_async(&self.conn) .await?; - Ok(ApiUserToken { - id: token_m.id, - api_user_id: token_m.api_user_id, - token: token_m.token, - permissions: token_m.permissions, - expires_at: token_m.expires_at, - created_at: token_m.created_at, - updated_at: token_m.updated_at, - deleted_at: token_m.deleted_at, + Ok(ApiKey { + id: key_m.id, + api_user_id: key_m.api_user_id, + key: key_m.key, + permissions: key_m.permissions, + expires_at: key_m.expires_at, + created_at: key_m.created_at, + updated_at: key_m.updated_at, + deleted_at: key_m.deleted_at, }) } - async fn delete(&self, id: &Uuid) -> Result>, StoreError> { - let _ = update(api_user_token::dsl::api_user_token) - .filter(api_user_token::id.eq(*id)) - .set(api_user_token::deleted_at.eq(Utc::now())) + async fn delete(&self, id: &Uuid) -> Result>, StoreError> { + let _ = update(api_key::dsl::api_key) + .filter(api_key::id.eq(*id)) + .set(api_key::deleted_at.eq(Utc::now())) .execute_async(&self.conn) .await?; - ApiUserTokenStore::get(self, id, true).await + ApiKeyStore::get(self, id, true).await } } @@ -864,3 +869,91 @@ impl AccessTokenStore for PostgresStore { }) } } + +#[async_trait] +impl LoginAttemptStore for PostgresStore { + async fn get(&self, id: &Uuid) -> Result, StoreError> { + let query = login_attempt::dsl::login_attempt + .into_boxed() + .filter(login_attempt::id.eq(*id)); + + let result = query + .get_result_async::(&self.conn) + .await + .optional()?; + + Ok(result.map(|attempt| attempt.into())) + } + + async fn list( + &self, + filter: LoginAttemptFilter, + pagination: &ListPagination, + ) -> Result, StoreError> { + let mut query = login_attempt::dsl::login_attempt.into_boxed(); + + let LoginAttemptFilter { + id, + client_id, + attempt_state, + authz_code, + } = filter; + + if let Some(id) = id { + query = query.filter(login_attempt::id.eq_any(id)); + } + + if let Some(client_id) = client_id { + query = query.filter(login_attempt::client_id.eq_any(client_id)); + } + + if let Some(attempt_state) = attempt_state { + query = query.filter(login_attempt::attempt_state.eq_any(attempt_state)); + } + + if let Some(authz_code) = authz_code { + query = query.filter(login_attempt::authz_code.eq_any(authz_code)); + } + + let results = query + .offset(pagination.offset) + .limit(pagination.limit) + .order(login_attempt::created_at.desc()) + .get_results_async::(&self.conn) + .await?; + + Ok(results.into_iter().map(|model| model.into()).collect()) + } + + async fn upsert(&self, attempt: NewLoginAttempt) -> Result { + let attempt_m: LoginAttemptModel = insert_into(login_attempt::dsl::login_attempt) + .values(( + login_attempt::id.eq(attempt.id), + login_attempt::attempt_state.eq(attempt.attempt_state), + login_attempt::client_id.eq(attempt.client_id), + login_attempt::redirect_uri.eq(attempt.redirect_uri), + login_attempt::state.eq(attempt.state), + login_attempt::pkce_challenge.eq(attempt.pkce_challenge), + login_attempt::pkce_challenge_method.eq(attempt.pkce_challenge_method), + login_attempt::authz_code.eq(attempt.authz_code), + login_attempt::expires_at.eq(attempt.expires_at), + login_attempt::provider.eq(attempt.provider), + login_attempt::provider_state.eq(attempt.provider_state), + login_attempt::provider_pkce_verifier.eq(attempt.provider_pkce_verifier), + login_attempt::provider_authz_code.eq(attempt.provider_authz_code), + )) + .on_conflict(login_attempt::id) + .do_update() + .set(( + login_attempt::attempt_state.eq(excluded(login_attempt::attempt_state)), + login_attempt::authz_code.eq(excluded(login_attempt::authz_code)), + login_attempt::expires_at.eq(excluded(login_attempt::expires_at)), + login_attempt::provider_authz_code.eq(excluded(login_attempt::provider_authz_code)), + login_attempt::updated_at.eq(Utc::now()), + )) + .get_result_async(&self.conn) + .await?; + + Ok(attempt_m.into()) + } +} diff --git a/rfd-model/tests/postgres.rs b/rfd-model/tests/postgres.rs index a4717e0..a7fade8 100644 --- a/rfd-model/tests/postgres.rs +++ b/rfd-model/tests/postgres.rs @@ -8,10 +8,10 @@ use diesel::{ use diesel_migrations::{embed_migrations, EmbeddedMigrations}; use rfd_model::{ storage::{ - postgres::PostgresStore, ApiUserFilter, ApiUserStore, ApiUserTokenFilter, - ApiUserTokenStore, ListPagination, + postgres::PostgresStore, ApiKeyFilter, ApiKeyStore, ApiUserFilter, ApiUserStore, + ListPagination, }, - NewApiUser, NewApiUserToken, + NewApiKey, NewApiUser, }; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; @@ -27,9 +27,9 @@ fn leakable_dbs() -> Vec { #[derive(Clone, Debug, Eq, PartialEq, Hash, Serialize, Deserialize, JsonSchema)] enum TestPermission { CreateApiUser, - CreateApiUserToken(Uuid), - GetApiUserToken(Uuid), - DeleteApiUserToken(Uuid), + CreateApiKey(Uuid), + GetApiKey(Uuid), + DeleteApiKey(Uuid), } // A fresh test database that will be created and migrated for use in a test. At the end of the @@ -123,7 +123,7 @@ async fn test_api_user() { &store, NewApiUser { id: api_user_id, - permissions: vec![TestPermission::CreateApiUserToken(api_user_id).into()].into(), + permissions: vec![TestPermission::CreateApiKey(api_user_id).into()].into(), }, ) .await @@ -142,7 +142,7 @@ async fn test_api_user() { &store, NewApiUser { id: api_user_id, - permissions: vec![TestPermission::CreateApiUserToken(api_user_id).into()].into(), + permissions: vec![TestPermission::CreateApiKey(api_user_id).into()].into(), }, ) .await @@ -156,9 +156,9 @@ async fn test_api_user() { NewApiUser { id: api_user_id, permissions: vec![ - TestPermission::CreateApiUserToken(api_user_id).into(), - TestPermission::GetApiUserToken(api_user_id).into(), - TestPermission::DeleteApiUserToken(api_user_id).into(), + TestPermission::CreateApiKey(api_user_id).into(), + TestPermission::GetApiKey(api_user_id).into(), + TestPermission::DeleteApiKey(api_user_id).into(), ] .into(), }, @@ -168,19 +168,19 @@ async fn test_api_user() { assert!(api_user .permissions - .can(&TestPermission::GetApiUserToken(api_user_id).into())); + .can(&TestPermission::GetApiKey(api_user_id).into())); assert!(api_user .permissions - .can(&TestPermission::DeleteApiUserToken(api_user_id).into())); + .can(&TestPermission::DeleteApiKey(api_user_id).into())); // 5. Create an API token for the user - let token = ApiUserTokenStore::upsert( + let token = ApiKeyStore::upsert( &store, - NewApiUserToken { + NewApiKey { id: Uuid::new_v4(), api_user_id: api_user.id, - token: format!("token-{}", Uuid::new_v4()), - permissions: vec![TestPermission::GetApiUserToken(api_user_id).into()].into(), + key: format!("key-{}", Uuid::new_v4()), + permissions: vec![TestPermission::GetApiKey(api_user_id).into()].into(), expires_at: Utc::now() + Duration::seconds(5 * 60), }, &api_user, @@ -189,15 +189,15 @@ async fn test_api_user() { .unwrap(); // 6. Create an API token with excess permissions for the user - let excess_token = ApiUserTokenStore::upsert( + let excess_token = ApiKeyStore::upsert( &store, - NewApiUserToken { + NewApiKey { id: Uuid::new_v4(), api_user_id: api_user.id, - token: format!("token-{}", Uuid::new_v4()), + key: format!("key-{}", Uuid::new_v4()), permissions: vec![ TestPermission::CreateApiUser.into(), - TestPermission::GetApiUserToken(api_user_id).into(), + TestPermission::GetApiKey(api_user_id).into(), ] .into(), expires_at: Utc::now() + Duration::seconds(5 * 60), @@ -212,15 +212,15 @@ async fn test_api_user() { .can(&TestPermission::CreateApiUser.into())); // 7. Create an API token with excess permissions for the user - let expired_token = ApiUserTokenStore::upsert( + let expired_token = ApiKeyStore::upsert( &store, - NewApiUserToken { + NewApiKey { id: Uuid::new_v4(), api_user_id: api_user.id, - token: format!("token-{}", Uuid::new_v4()), + key: format!("key-{}", Uuid::new_v4()), permissions: vec![ TestPermission::CreateApiUser.into(), - TestPermission::GetApiUserToken(api_user_id).into(), + TestPermission::GetApiKey(api_user_id).into(), ] .into(), expires_at: Utc::now() - Duration::seconds(5 * 60), @@ -233,10 +233,11 @@ async fn test_api_user() { assert!(expired_token.expires_at < Utc::now()); // 8. List the active API tokens for the user - let tokens = ApiUserTokenStore::list( + let tokens = ApiKeyStore::list( &store, - ApiUserTokenFilter { + ApiKeyFilter { api_user_id: Some(vec![api_user.id]), + key: None, expired: false, deleted: false, }, @@ -250,10 +251,11 @@ async fn test_api_user() { assert!(tokens.contains(&excess_token)); // 9. List all API tokens for the user - let all_tokens = ApiUserTokenStore::list( + let all_tokens = ApiKeyStore::list( &store, - ApiUserTokenFilter { + ApiKeyFilter { api_user_id: Some(vec![api_user.id]), + key: None, expired: true, deleted: false, }, @@ -266,7 +268,7 @@ async fn test_api_user() { assert!(all_tokens.contains(&expired_token)); // 10. Lookup an API token for the user - let token_lookup = ApiUserTokenStore::::get(&store, &token.id, false) + let token_lookup = ApiKeyStore::::get(&store, &token.id, false) .await .unwrap() .unwrap(); @@ -275,16 +277,17 @@ async fn test_api_user() { // 11. Delete the API tokens for the user for token in all_tokens { - let _ = ApiUserTokenStore::::delete(&store, &token.id) + let _ = ApiKeyStore::::delete(&store, &token.id) .await .unwrap(); } // 12. List the deleted API tokens for the user - let deleted_tokens = ApiUserTokenStore::::list( + let deleted_tokens = ApiKeyStore::::list( &store, - ApiUserTokenFilter { + ApiKeyFilter { api_user_id: Some(vec![api_user.id]), + key: None, expired: true, deleted: true, },