diff --git a/rfd-api/src/endpoints/login/oauth/code.rs b/rfd-api/src/endpoints/login/oauth/code.rs index a5ea790..3103e9e 100644 --- a/rfd-api/src/endpoints/login/oauth/code.rs +++ b/rfd-api/src/endpoints/login/oauth/code.rs @@ -287,6 +287,13 @@ pub async fn authz_code_callback_op( // If we fail to find a matching attempt, there is not much we can do other than return // unauthorized unauthorized() + }) + .and_then(|attempt| { + if attempt.attempt_state == LoginAttemptState::New { + Ok(attempt) + } else { + Err(unauthorized()) + } })?; attempt = match (code, error) { @@ -306,8 +313,16 @@ pub async fn authz_code_callback_op( // returned or there is a missing code, then we can not report a successful process attempt.provider_authz_code = code; + // When a user has explicitly denied access we want to forward that error message + // onwards to the upstream requester. All other errors should be opaque to the + // original requester and are returned as server errors + let error_message = match error.as_deref() { + Some("access_denied") => "access_denied", + _ => "server_error", + }; + // TODO: Specialize the returned error - ctx.fail_login_attempt(attempt, Some("server_error"), error.as_deref()) + ctx.fail_login_attempt(attempt, Some(error_message), error.as_deref()) .await .map_err(to_internal_error)? } @@ -488,7 +503,7 @@ mod tests { use dropshot::RequestInfo; use http::{ header::{COOKIE, LOCATION, SET_COOKIE}, - HeaderMap, HeaderValue, + HeaderValue, StatusCode, }; use hyper::Body; use mockall::predicate::eq; @@ -530,6 +545,170 @@ mod tests { assert_eq!(id, verify_csrf(&request, &query).unwrap()); } + #[tokio::test] + async fn test_callback_fails_when_not_in_new_state() { + let invalid_states = [ + LoginAttemptState::Complete, + LoginAttemptState::Failed, + LoginAttemptState::RemoteAuthenticated, + ]; + + for state in invalid_states { + let attempt_id = Uuid::new_v4(); + let attempt = LoginAttempt { + id: attempt_id, + attempt_state: state, + client_id: Uuid::new_v4(), + redirect_uri: "https://test.oxeng.dev/callback".to_string(), + state: Some("ox_state".to_string()), + pkce_challenge: Some("ox_challenge".to_string()), + pkce_challenge_method: Some("S256".to_string()), + authz_code: None, + expires_at: None, + error: None, + provider: "google".to_string(), + provider_pkce_verifier: Some("rfd_verifier".to_string()), + provider_authz_code: None, + provider_error: None, + created_at: Utc::now(), + updated_at: Utc::now(), + }; + + let mut storage = MockStorage::new(); + let mut attempt_store = MockLoginAttemptStore::new(); + attempt_store + .expect_get() + .with(eq(attempt.id)) + .returning(move |_| Ok(Some(attempt.clone()))); + storage.login_attempt_store = Some(Arc::new(attempt_store)); + + let ctx = mock_context(storage).await; + let err = + authz_code_callback_op(&ctx, &attempt_id, Some("remote-code".to_string()), None) + .await; + + assert_eq!(StatusCode::UNAUTHORIZED, err.unwrap_err().status_code); + } + } + + #[tokio::test] + async fn test_callback_fails_when_error_is_passed() { + let attempt_id = Uuid::new_v4(); + let attempt = LoginAttempt { + id: attempt_id, + attempt_state: LoginAttemptState::New, + client_id: Uuid::new_v4(), + redirect_uri: "https://test.oxeng.dev/callback".to_string(), + state: Some("ox_state".to_string()), + pkce_challenge: Some("ox_challenge".to_string()), + pkce_challenge_method: Some("S256".to_string()), + authz_code: None, + expires_at: None, + error: None, + provider: "google".to_string(), + provider_pkce_verifier: Some("rfd_verifier".to_string()), + provider_authz_code: None, + provider_error: None, + created_at: Utc::now(), + updated_at: Utc::now(), + }; + + let mut attempt_store = MockLoginAttemptStore::new(); + let original_attempt = attempt.clone(); + attempt_store + .expect_get() + .with(eq(attempt.id)) + .returning(move |_| Ok(Some(original_attempt.clone()))); + + attempt_store + .expect_upsert() + .withf(|attempt| attempt.attempt_state == LoginAttemptState::Failed) + .returning(move |arg| { + let mut returned = attempt.clone(); + returned.attempt_state = arg.attempt_state; + returned.authz_code = arg.authz_code; + returned.error = arg.error; + Ok(returned) + }); + + let mut storage = MockStorage::new(); + storage.login_attempt_store = Some(Arc::new(attempt_store)); + let ctx = mock_context(storage).await; + + let location = authz_code_callback_op( + &ctx, + &attempt_id, + Some("remote-code".to_string()), + Some("not_access_denied".to_string()), + ) + .await + .unwrap(); + + assert_eq!( + format!("https://test.oxeng.dev/callback?error=server_error&state=ox_state",), + location + ); + } + + #[tokio::test] + async fn test_callback_forwards_access_denied() { + let attempt_id = Uuid::new_v4(); + let attempt = LoginAttempt { + id: attempt_id, + attempt_state: LoginAttemptState::New, + client_id: Uuid::new_v4(), + redirect_uri: "https://test.oxeng.dev/callback".to_string(), + state: Some("ox_state".to_string()), + pkce_challenge: Some("ox_challenge".to_string()), + pkce_challenge_method: Some("S256".to_string()), + authz_code: None, + expires_at: None, + error: None, + provider: "google".to_string(), + provider_pkce_verifier: Some("rfd_verifier".to_string()), + provider_authz_code: None, + provider_error: None, + created_at: Utc::now(), + updated_at: Utc::now(), + }; + + let mut attempt_store = MockLoginAttemptStore::new(); + let original_attempt = attempt.clone(); + attempt_store + .expect_get() + .with(eq(attempt.id)) + .returning(move |_| Ok(Some(original_attempt.clone()))); + + attempt_store + .expect_upsert() + .withf(|attempt| attempt.attempt_state == LoginAttemptState::Failed) + .returning(move |arg| { + let mut returned = attempt.clone(); + returned.attempt_state = arg.attempt_state; + returned.authz_code = arg.authz_code; + returned.error = arg.error; + Ok(returned) + }); + + let mut storage = MockStorage::new(); + storage.login_attempt_store = Some(Arc::new(attempt_store)); + let ctx = mock_context(storage).await; + + let location = authz_code_callback_op( + &ctx, + &attempt_id, + Some("remote-code".to_string()), + Some("access_denied".to_string()), + ) + .await + .unwrap(); + + assert_eq!( + format!("https://test.oxeng.dev/callback?error=access_denied&state=ox_state",), + location + ); + } + #[tokio::test] async fn test_handles_callback_with_code() { let attempt_id = Uuid::new_v4();