Skip to content

Commit

Permalink
Add tests for error code returns. Fix callback to forward access deni…
Browse files Browse the repository at this point in the history
…ed errors
  • Loading branch information
augustuswm committed Sep 26, 2023
1 parent cc13c54 commit 8831560
Showing 1 changed file with 181 additions and 2 deletions.
183 changes: 181 additions & 2 deletions rfd-api/src/endpoints/login/oauth/code.rs
Original file line number Diff line number Diff line change
Expand Up @@ -287,6 +287,13 @@ pub async fn authz_code_callback_op(
// If we fail to find a matching attempt, there is not much we can do other than return
// unauthorized
unauthorized()
})
.and_then(|attempt| {
if attempt.attempt_state == LoginAttemptState::New {
Ok(attempt)
} else {
Err(unauthorized())
}
})?;

attempt = match (code, error) {
Expand All @@ -306,8 +313,16 @@ pub async fn authz_code_callback_op(
// returned or there is a missing code, then we can not report a successful process
attempt.provider_authz_code = code;

// When a user has explicitly denied access we want to forward that error message
// onwards to the upstream requester. All other errors should be opaque to the
// original requester and are returned as server errors
let error_message = match error.as_deref() {
Some("access_denied") => "access_denied",
_ => "server_error",
};

// TODO: Specialize the returned error
ctx.fail_login_attempt(attempt, Some("server_error"), error.as_deref())
ctx.fail_login_attempt(attempt, Some(error_message), error.as_deref())
.await
.map_err(to_internal_error)?
}
Expand Down Expand Up @@ -488,7 +503,7 @@ mod tests {
use dropshot::RequestInfo;
use http::{
header::{COOKIE, LOCATION, SET_COOKIE},
HeaderMap, HeaderValue,
HeaderValue, StatusCode,
};
use hyper::Body;
use mockall::predicate::eq;
Expand Down Expand Up @@ -530,6 +545,170 @@ mod tests {
assert_eq!(id, verify_csrf(&request, &query).unwrap());
}

#[tokio::test]
async fn test_callback_fails_when_not_in_new_state() {
let invalid_states = [
LoginAttemptState::Complete,
LoginAttemptState::Failed,
LoginAttemptState::RemoteAuthenticated,
];

for state in invalid_states {
let attempt_id = Uuid::new_v4();
let attempt = LoginAttempt {
id: attempt_id,
attempt_state: state,
client_id: Uuid::new_v4(),
redirect_uri: "https://test.oxeng.dev/callback".to_string(),
state: Some("ox_state".to_string()),
pkce_challenge: Some("ox_challenge".to_string()),
pkce_challenge_method: Some("S256".to_string()),
authz_code: None,
expires_at: None,
error: None,
provider: "google".to_string(),
provider_pkce_verifier: Some("rfd_verifier".to_string()),
provider_authz_code: None,
provider_error: None,
created_at: Utc::now(),
updated_at: Utc::now(),
};

let mut storage = MockStorage::new();
let mut attempt_store = MockLoginAttemptStore::new();
attempt_store
.expect_get()
.with(eq(attempt.id))
.returning(move |_| Ok(Some(attempt.clone())));
storage.login_attempt_store = Some(Arc::new(attempt_store));

let ctx = mock_context(storage).await;
let err =
authz_code_callback_op(&ctx, &attempt_id, Some("remote-code".to_string()), None)
.await;

assert_eq!(StatusCode::UNAUTHORIZED, err.unwrap_err().status_code);
}
}

#[tokio::test]
async fn test_callback_fails_when_error_is_passed() {
let attempt_id = Uuid::new_v4();
let attempt = LoginAttempt {
id: attempt_id,
attempt_state: LoginAttemptState::New,
client_id: Uuid::new_v4(),
redirect_uri: "https://test.oxeng.dev/callback".to_string(),
state: Some("ox_state".to_string()),
pkce_challenge: Some("ox_challenge".to_string()),
pkce_challenge_method: Some("S256".to_string()),
authz_code: None,
expires_at: None,
error: None,
provider: "google".to_string(),
provider_pkce_verifier: Some("rfd_verifier".to_string()),
provider_authz_code: None,
provider_error: None,
created_at: Utc::now(),
updated_at: Utc::now(),
};

let mut attempt_store = MockLoginAttemptStore::new();
let original_attempt = attempt.clone();
attempt_store
.expect_get()
.with(eq(attempt.id))
.returning(move |_| Ok(Some(original_attempt.clone())));

attempt_store
.expect_upsert()
.withf(|attempt| attempt.attempt_state == LoginAttemptState::Failed)
.returning(move |arg| {
let mut returned = attempt.clone();
returned.attempt_state = arg.attempt_state;
returned.authz_code = arg.authz_code;
returned.error = arg.error;
Ok(returned)
});

let mut storage = MockStorage::new();
storage.login_attempt_store = Some(Arc::new(attempt_store));
let ctx = mock_context(storage).await;

let location = authz_code_callback_op(
&ctx,
&attempt_id,
Some("remote-code".to_string()),
Some("not_access_denied".to_string()),
)
.await
.unwrap();

assert_eq!(
format!("https://test.oxeng.dev/callback?error=server_error&state=ox_state",),
location
);
}

#[tokio::test]
async fn test_callback_forwards_access_denied() {
let attempt_id = Uuid::new_v4();
let attempt = LoginAttempt {
id: attempt_id,
attempt_state: LoginAttemptState::New,
client_id: Uuid::new_v4(),
redirect_uri: "https://test.oxeng.dev/callback".to_string(),
state: Some("ox_state".to_string()),
pkce_challenge: Some("ox_challenge".to_string()),
pkce_challenge_method: Some("S256".to_string()),
authz_code: None,
expires_at: None,
error: None,
provider: "google".to_string(),
provider_pkce_verifier: Some("rfd_verifier".to_string()),
provider_authz_code: None,
provider_error: None,
created_at: Utc::now(),
updated_at: Utc::now(),
};

let mut attempt_store = MockLoginAttemptStore::new();
let original_attempt = attempt.clone();
attempt_store
.expect_get()
.with(eq(attempt.id))
.returning(move |_| Ok(Some(original_attempt.clone())));

attempt_store
.expect_upsert()
.withf(|attempt| attempt.attempt_state == LoginAttemptState::Failed)
.returning(move |arg| {
let mut returned = attempt.clone();
returned.attempt_state = arg.attempt_state;
returned.authz_code = arg.authz_code;
returned.error = arg.error;
Ok(returned)
});

let mut storage = MockStorage::new();
storage.login_attempt_store = Some(Arc::new(attempt_store));
let ctx = mock_context(storage).await;

let location = authz_code_callback_op(
&ctx,
&attempt_id,
Some("remote-code".to_string()),
Some("access_denied".to_string()),
)
.await
.unwrap();

assert_eq!(
format!("https://test.oxeng.dev/callback?error=access_denied&state=ox_state",),
location
);
}

#[tokio::test]
async fn test_handles_callback_with_code() {
let attempt_id = Uuid::new_v4();
Expand Down

0 comments on commit 8831560

Please sign in to comment.