Skip to content

Commit

Permalink
Add tests for grant code checks. Create structured oauth error
Browse files Browse the repository at this point in the history
  • Loading branch information
augustuswm committed Sep 27, 2023
1 parent f048c84 commit 29c3d1b
Showing 1 changed file with 282 additions and 24 deletions.
306 changes: 282 additions & 24 deletions rfd-api/src/endpoints/login/oauth/code.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use base64::{prelude::BASE64_URL_SAFE_NO_PAD, Engine};
use chrono::{Duration, Utc};
use dropshot::{
endpoint, http_response_temporary_redirect, HttpError, HttpResponseOk,
Expand All @@ -16,6 +17,7 @@ use oauth2::{
use rfd_model::{schema_ext::LoginAttemptState, LoginAttempt, NewLoginAttempt, OAuthClient};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256};
use std::fmt::Debug;
use std::ops::Add;
use tap::TapFallible;
Expand All @@ -39,6 +41,38 @@ use crate::{

static LOGIN_ATTEMPT_COOKIE: &str = "__rfd_login";

#[derive(Debug, Deserialize, JsonSchema, Serialize, PartialEq, Eq)]
struct OAuthError {
error: OAuthErrorCode,
#[serde(skip_serializing_if = "Option::is_none")]
error_description: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
error_uri: Option<String>,
}

#[derive(Debug, Deserialize, JsonSchema, Serialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
enum OAuthErrorCode {
InvalidRequest,
InvalidClient,
InvalidGrant,
UnauthorizedClient,
UnsupportedGrantType,
InvalidScope,
}

impl From<OAuthError> for HttpError {
fn from(value: OAuthError) -> Self {
let serialized = serde_json::to_string(&value).unwrap();
HttpError {
status_code: StatusCode::BAD_REQUEST,
error_code: None,
external_message: serialized.clone(),
internal_message: serialized,
}
}
}

#[derive(Debug, Deserialize, JsonSchema, Serialize)]
pub struct OAuthAuthzCodeQuery {
pub client_id: Uuid,
Expand Down Expand Up @@ -389,10 +423,15 @@ pub async fn authz_code_exchange(
.await
.map_err(to_internal_error)?
// TODO: Bad request is ok, but a json body with invalid_grant should be returned
.ok_or_else(|| bad_request("Invalid code".to_string()))?;
.ok_or_else(|| bad_request("invalid_grant".to_string()))?;

// Verify that the login attempt is valid and matches the submitted client credentials
verify_login_attempt(&attempt, &body.client_id, &body.redirect_uri)?;
verify_login_attempt(
&attempt,
&body.client_id,
&body.redirect_uri,
body.pkce_verifier.as_deref(),
)?;

tracing::debug!("Verified login attempt");

Expand Down Expand Up @@ -440,7 +479,7 @@ async fn authorize_exchange(
// Verify that we received the expected grant type
if grant_type != "authorization_code" {
// TODO: Needs to be json body
return Err(bad_request("Invalid grant type"));
return Err(bad_request("unsupported_grant_type"));
}

let client_secret = RawApiKey::try_from(client_secret).map_err(|err| {
Expand All @@ -452,7 +491,7 @@ async fn authorize_exchange(

if !client.is_secret_valid(&client_secret, &*ctx.secrets.signer) {
// TODO: Change this to a bad request with invalid_client ?
Err(client_error(StatusCode::UNAUTHORIZED, "Invalid secret"))
Err(client_error(StatusCode::UNAUTHORIZED, "invalid_client"))
} else {
Ok(())
}
Expand All @@ -462,23 +501,57 @@ fn verify_login_attempt(
attempt: &LoginAttempt,
client_id: &Uuid,
redirect_uri: &str,
) -> Result<(), HttpError> {
pkce_verifier: Option<&str>,
) -> Result<(), OAuthError> {
if attempt.client_id != *client_id {
// TODO: Bad request is ok, but a json body with invalid_grant should be returned
Err(bad_request("Invalid client id".to_string()))
Err(OAuthError {
error: OAuthErrorCode::InvalidGrant,
error_description: Some("Invalid client id".to_string()),
error_uri: None,
})
} else if attempt.redirect_uri != redirect_uri {
// TODO: Bad request is ok, but a json body with invalid_grant should be returned
Err(bad_request("Invalid redirect uri".to_string()))
Err(OAuthError {
error: OAuthErrorCode::InvalidGrant,
error_description: Some("Invalid redirect uri".to_string()),
error_uri: None,
})
} else if attempt.attempt_state != LoginAttemptState::RemoteAuthenticated {
// TODO: Bad request is ok, but a json body with invalid_client should be returned
Err(bad_request("Invalid login state".to_string()))
Err(OAuthError {
error: OAuthErrorCode::InvalidGrant,
error_description: Some("Grant is in an invalid state".to_string()),
error_uri: None,
})
} else if attempt.expires_at.map(|t| t <= Utc::now()).unwrap_or(true) {
// TODO: Bad request is ok, but a json body with invalid_client should be returned
Err(bad_request("Login attempt expired".to_string()))
Err(OAuthError {
error: OAuthErrorCode::InvalidGrant,
error_description: Some("Grant has expired".to_string()),
error_uri: None,
})
} else {
// TODO: Perform pkce check

Ok(())
match (attempt.pkce_challenge.as_deref(), pkce_verifier) {
(Some(_), None) => Err(OAuthError {
error: OAuthErrorCode::InvalidRequest,
error_description: Some("Missing pkce verifier".to_string()),
error_uri: None,
}),
(Some(challenge), Some(verifier)) => {
let mut hasher = Sha256::new();
hasher.update(verifier);
let hash = hasher.finalize();
let computed_challenge = BASE64_URL_SAFE_NO_PAD.encode(hash);

if challenge == computed_challenge {
Ok(())
} else {
Err(OAuthError {
error: OAuthErrorCode::InvalidGrant,
error_description: Some("Invalid pkce verifier".to_string()),
error_uri: None,
})
}
}
(None, _) => Ok(()),
}
}
}

Expand Down Expand Up @@ -534,10 +607,11 @@ async fn fetch_user_info(
mod tests {
use std::{
net::{Ipv4Addr, SocketAddrV4},
ops::Add,
sync::{Arc, Mutex},
};

use chrono::Utc;
use chrono::{Duration, Utc};
use dropshot::RequestInfo;
use http::{
header::{COOKIE, LOCATION, SET_COOKIE},
Expand All @@ -560,7 +634,10 @@ mod tests {
ApiContext,
},
endpoints::login::oauth::{
code::{verify_csrf, OAuthAuthzCodeReturnQuery, LOGIN_ATTEMPT_COOKIE},
code::{
verify_csrf, verify_login_attempt, OAuthAuthzCodeReturnQuery, OAuthError,
OAuthErrorCode, LOGIN_ATTEMPT_COOKIE,
},
OAuthProviderName,
},
};
Expand Down Expand Up @@ -1180,13 +1257,194 @@ mod tests {
);
}

#[tokio::test]
async fn test_exchange_fails_on_invalid_code() {
unimplemented!()
}

#[tokio::test]
async fn test_login_attempt_verification() {
unimplemented!()
let (challenge, verifier) = PkceCodeChallenge::new_random_sha256();
let attempt = LoginAttempt {
id: Uuid::new_v4(),
attempt_state: LoginAttemptState::RemoteAuthenticated,
client_id: Uuid::new_v4(),
redirect_uri: "https://test.oxeng.dev/callback".to_string(),
state: Some("ox_state".to_string()),
pkce_challenge: Some(challenge.as_str().to_string()),
pkce_challenge_method: Some("S256".to_string()),
authz_code: None,
expires_at: Some(Utc::now().add(Duration::seconds(60))),
error: None,
provider: "google".to_string(),
provider_pkce_verifier: Some("rfd_verifier".to_string()),
provider_authz_code: None,
provider_error: None,
created_at: Utc::now(),
updated_at: Utc::now(),
};

let bad_client_id = LoginAttempt {
client_id: Uuid::new_v4(),
..attempt.clone()
};

assert_eq!(
OAuthError {
error: OAuthErrorCode::InvalidGrant,
error_description: Some("Invalid client id".to_string()),
error_uri: None,
},
verify_login_attempt(
&bad_client_id,
&attempt.client_id,
&attempt.redirect_uri,
Some(verifier.secret().as_str()),
)
.unwrap_err()
);

let bad_redirect_uri = LoginAttempt {
redirect_uri: "https://bad.oxeng.dev/callback".to_string(),
..attempt.clone()
};

assert_eq!(
OAuthError {
error: OAuthErrorCode::InvalidGrant,
error_description: Some("Invalid redirect uri".to_string()),
error_uri: None,
},
verify_login_attempt(
&bad_redirect_uri,
&attempt.client_id,
&attempt.redirect_uri,
Some(verifier.secret().as_str()),
)
.unwrap_err()
);

let unconfirmed_state = LoginAttempt {
attempt_state: LoginAttemptState::New,
..attempt.clone()
};

assert_eq!(
OAuthError {
error: OAuthErrorCode::InvalidGrant,
error_description: Some("Grant is in an invalid state".to_string()),
error_uri: None,
},
verify_login_attempt(
&unconfirmed_state,
&attempt.client_id,
&attempt.redirect_uri,
Some(verifier.secret().as_str()),
)
.unwrap_err()
);

let already_used_state = LoginAttempt {
attempt_state: LoginAttemptState::Complete,
..attempt.clone()
};

assert_eq!(
OAuthError {
error: OAuthErrorCode::InvalidGrant,
error_description: Some("Grant is in an invalid state".to_string()),
error_uri: None,
},
verify_login_attempt(
&already_used_state,
&attempt.client_id,
&attempt.redirect_uri,
Some(verifier.secret().as_str()),
)
.unwrap_err()
);

let failed_state = LoginAttempt {
attempt_state: LoginAttemptState::Failed,
..attempt.clone()
};

assert_eq!(
OAuthError {
error: OAuthErrorCode::InvalidGrant,
error_description: Some("Grant is in an invalid state".to_string()),
error_uri: None,
},
verify_login_attempt(
&failed_state,
&attempt.client_id,
&attempt.redirect_uri,
Some(verifier.secret().as_str()),
)
.unwrap_err()
);

let expired = LoginAttempt {
expires_at: Some(Utc::now()),
..attempt.clone()
};

assert_eq!(
OAuthError {
error: OAuthErrorCode::InvalidGrant,
error_description: Some("Grant has expired".to_string()),
error_uri: None,
},
verify_login_attempt(
&expired,
&attempt.client_id,
&attempt.redirect_uri,
Some(verifier.secret().as_str()),
)
.unwrap_err()
);

let missing_pkce = LoginAttempt { ..attempt.clone() };

assert_eq!(
OAuthError {
error: OAuthErrorCode::InvalidRequest,
error_description: Some("Missing pkce verifier".to_string()),
error_uri: None,
},
verify_login_attempt(
&missing_pkce,
&attempt.client_id,
&attempt.redirect_uri,
None,
)
.unwrap_err()
);

let invalid_pkce = LoginAttempt {
pkce_challenge: Some("no-the-correct-value".to_string()),
..attempt.clone()
};

assert_eq!(
OAuthError {
error: OAuthErrorCode::InvalidGrant,
error_description: Some("Invalid pkce verifier".to_string()),
error_uri: None,
},
verify_login_attempt(
&invalid_pkce,
&attempt.client_id,
&attempt.redirect_uri,
Some(verifier.secret().as_str()),
)
.unwrap_err()
);

assert_eq!(
(),
verify_login_attempt(
&attempt,
&attempt.client_id,
&attempt.redirect_uri,
Some(verifier.secret().as_str()),
)
.unwrap()
);
}
}

0 comments on commit 29c3d1b

Please sign in to comment.