From ece9f8a9b9b6ead400890a8364ccf4fb79958841 Mon Sep 17 00:00:00 2001 From: augustuswm Date: Tue, 26 Sep 2023 13:26:30 -0500 Subject: [PATCH] Expand csrf test --- rfd-api/src/endpoints/login/oauth/code.rs | 54 +++++++++++++++++++++-- 1 file changed, 51 insertions(+), 3 deletions(-) diff --git a/rfd-api/src/endpoints/login/oauth/code.rs b/rfd-api/src/endpoints/login/oauth/code.rs index 3103e9e0..cdcfde9d 100644 --- a/rfd-api/src/endpoints/login/oauth/code.rs +++ b/rfd-api/src/endpoints/login/oauth/code.rs @@ -271,6 +271,7 @@ pub async fn authz_code_callback( ) } +#[instrument(skip(ctx, code), err(Debug))] pub async fn authz_code_callback_op( ctx: &ApiContext, attempt_id: &Uuid, @@ -530,19 +531,66 @@ mod tests { COOKIE, HeaderValue::from_str(&format!("{}={}", LOGIN_ATTEMPT_COOKIE, id)).unwrap(), ); - - let request = RequestInfo::new( + let with_valid_cookie = RequestInfo::new( &rq, std::net::SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(0, 0, 0, 0), 8888)), ); + let query = OAuthAuthzCodeReturnQuery { + state: Some(id.to_string()), + code: None, + error: None, + }; + assert_eq!(id, verify_csrf(&with_valid_cookie, &query).unwrap()); + let query = OAuthAuthzCodeReturnQuery { + state: None, + code: None, + error: None, + }; + assert_eq!( + StatusCode::UNAUTHORIZED, + verify_csrf(&with_valid_cookie, &query) + .unwrap_err() + .status_code + ); + + let mut rq = hyper::Request::new(Body::empty()); + rq.headers_mut().insert( + COOKIE, + HeaderValue::from_str(&format!("{}={}", LOGIN_ATTEMPT_COOKIE, Uuid::new_v4())).unwrap(), + ); + let with_invalid_cookie = RequestInfo::new( + &rq, + std::net::SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(0, 0, 0, 0), 8888)), + ); let query = OAuthAuthzCodeReturnQuery { state: Some(id.to_string()), code: None, error: None, }; + assert_eq!( + StatusCode::UNAUTHORIZED, + verify_csrf(&with_invalid_cookie, &query) + .unwrap_err() + .status_code + ); - assert_eq!(id, verify_csrf(&request, &query).unwrap()); + let rq = hyper::Request::new(Body::empty()); + let with_missing_cookie = RequestInfo::new( + &rq, + std::net::SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(0, 0, 0, 0), 8888)), + ); + let query = OAuthAuthzCodeReturnQuery { + state: Some(id.to_string()), + code: None, + error: None, + }; + assert_eq!( + StatusCode::UNAUTHORIZED, + verify_csrf(&with_missing_cookie, &query) + .unwrap_err() + .status_code + ); } #[tokio::test]