diff --git a/rfd-api/src/endpoints/login/oauth/code.rs b/rfd-api/src/endpoints/login/oauth/code.rs index 5bc6d44..4e57eca 100644 --- a/rfd-api/src/endpoints/login/oauth/code.rs +++ b/rfd-api/src/endpoints/login/oauth/code.rs @@ -1,7 +1,7 @@ use chrono::{Duration, Utc}; use dropshot::{ endpoint, http_response_temporary_redirect, HttpError, HttpResponseOk, - HttpResponseTemporaryRedirect, Path, Query, RequestContext, TypedBody, + HttpResponseTemporaryRedirect, Path, Query, RequestContext, RequestInfo, TypedBody, }; use http::{ header::{LOCATION, SET_COOKIE}, @@ -185,39 +185,16 @@ fn oauth_redirect_response( .body(Body::empty())?) } -#[derive(Debug, Deserialize, JsonSchema, Serialize)] -pub struct OAuthAuthzCodeReturnQuery { - pub state: Option, - pub code: Option, - pub error: Option, -} - -/// Handle return calls from a remote OAuth provider -#[endpoint { - method = GET, - path = "/login/oauth/{provider}/code/callback" -}] -#[instrument(skip(rqctx), fields(request_id = rqctx.request_id), err(Debug))] -pub async fn authz_code_callback( - rqctx: RequestContext, - path: Path, - query: Query, -) -> Result { - let ctx = rqctx.context(); - let path = path.into_inner(); - let query = query.into_inner(); - let provider = ctx - .get_oauth_provider(&path.provider) - .await - .map_err(ApiError::OAuth)?; - - tracing::debug!(provider = ?provider.name(), "Acquired OAuth provider for authz code exchange"); - +fn verify_csrf( + request: &RequestInfo, + query: &OAuthAuthzCodeReturnQuery, +) -> Result { // If we are missing the expected state parameter then we can not proceed at all with verifying // this callback request. We also do not have a redirect uri to send the user to so we instead // report unauthorized let attempt_id = query .state + .as_ref() .ok_or_else(|| { tracing::warn!("OAuth callback is missing a state parameter"); unauthorized() @@ -231,8 +208,7 @@ pub async fn authz_code_callback( // The client must present the attempt cookie at a minimum. Without it we are unable to lookup a // login attempt to match against. Without the cookie to verify the state parameter we can not // determine a redirect uri so we instead report unauthorized - let attempt_cookie: Uuid = rqctx - .request + let attempt_cookie: Uuid = request .cookie(LOGIN_ATTEMPT_COOKIE) .ok_or_else(|| { tracing::warn!("OAuth callback is missing a login state cookie"); @@ -253,8 +229,42 @@ pub async fn authz_code_callback( ?attempt_cookie, "OAuth state does not match expected cookie value" ); - return Err(unauthorized()); + Err(unauthorized()) + } else { + Ok(attempt_id) } +} + +#[derive(Debug, Deserialize, JsonSchema, Serialize)] +pub struct OAuthAuthzCodeReturnQuery { + pub state: Option, + pub code: Option, + pub error: Option, +} + +/// Handle return calls from a remote OAuth provider +#[endpoint { + method = GET, + path = "/login/oauth/{provider}/code/callback" +}] +#[instrument(skip(rqctx), fields(request_id = rqctx.request_id), err(Debug))] +pub async fn authz_code_callback( + rqctx: RequestContext, + path: Path, + query: Query, +) -> Result { + let ctx = rqctx.context(); + let path = path.into_inner(); + let query = query.into_inner(); + let provider = ctx + .get_oauth_provider(&path.provider) + .await + .map_err(ApiError::OAuth)?; + + tracing::debug!(provider = ?provider.name(), "Acquired OAuth provider for authz code exchange"); + + // Verify and extract the attempt id before performing any work + let attempt_id = verify_csrf(&rqctx.request, &query)?; // We have now verified the attempt id and can use it to look up the rest of the login attempt // material to try and complete the flow @@ -458,19 +468,54 @@ pub async fn authz_code_exchange( #[cfg(test)] mod tests { + use std::net::{Ipv4Addr, SocketAddrV4}; + use chrono::Utc; - use http::header::{LOCATION, SET_COOKIE}; + use dropshot::RequestInfo; + use http::{ + header::{COOKIE, LOCATION, SET_COOKIE}, + HeaderValue, + }; + use hyper::Body; use oauth2::PkceCodeChallenge; use rfd_model::{schema_ext::LoginAttemptState, LoginAttempt}; use uuid::Uuid; use crate::{ context::tests::{mock_context, MockStorage}, - endpoints::login::oauth::OAuthProviderName, + endpoints::login::oauth::{ + code::{verify_csrf, OAuthAuthzCodeReturnQuery, LOGIN_ATTEMPT_COOKIE}, + OAuthProviderName, + }, + util::request::RequestCookies, }; use super::oauth_redirect_response; + #[tokio::test] + async fn test_csrf_check() { + let id = Uuid::new_v4(); + + let mut rq = hyper::Request::new(Body::empty()); + rq.headers_mut().insert( + COOKIE, + HeaderValue::from_str(&format!("{}={}", LOGIN_ATTEMPT_COOKIE, id)).unwrap(), + ); + + let request = RequestInfo::new( + &rq, + std::net::SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(0, 0, 0, 0), 8888)), + ); + + let query = OAuthAuthzCodeReturnQuery { + state: Some(id.to_string()), + code: None, + error: None, + }; + + assert_eq!(id, verify_csrf(&request, &query).unwrap()); + } + #[tokio::test] async fn test_remote_provider_redirect_url() { let storage = MockStorage::new(); diff --git a/rfd-api/src/util.rs b/rfd-api/src/util.rs index 6898224..2466977 100644 --- a/rfd-api/src/util.rs +++ b/rfd-api/src/util.rs @@ -9,7 +9,7 @@ use crate::authn::CloudKmsError; pub mod request { use cookie::Cookie; use dropshot::RequestInfo; - use http::header::SET_COOKIE; + use http::header::COOKIE; pub trait RequestCookies { fn cookie(&self, name: &str) -> Option; @@ -17,7 +17,7 @@ pub mod request { impl RequestCookies for RequestInfo { fn cookie(&self, name: &str) -> Option { - let cookie_header = self.headers().get(SET_COOKIE)?; + let cookie_header = self.headers().get(COOKIE)?; Cookie::split_parse(String::from_utf8(cookie_header.as_bytes().to_vec()).unwrap()) .filter_map(|cookie| match cookie {