Skip to content

Commit

Permalink
Add github oauth provider
Browse files Browse the repository at this point in the history
  • Loading branch information
augustuswm committed Sep 11, 2023
1 parent d4102c5 commit d73567d
Show file tree
Hide file tree
Showing 9 changed files with 223 additions and 75 deletions.
14 changes: 9 additions & 5 deletions rfd-api/src/authn/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use dropshot::{HttpError, RequestContext, SharedExtractor};
use dropshot_authorization_header::bearer::BearerAuth;
use thiserror::Error;

use crate::{context::ApiContext, util::response::unauthorized, authn::key::RawApiKey};
use crate::{authn::key::RawApiKey, context::ApiContext, util::response::unauthorized};

use self::{jwt::Jwt, key::EncryptedApiKey};

Expand Down Expand Up @@ -54,10 +54,14 @@ impl AuthToken {
Err(err) => {
tracing::debug!(?err, "Token is not a JWT, falling back to API key");

Ok(AuthToken::ApiKey(RawApiKey::new(token).encrypt(&*ctx.secrets.encryptor).await.map_err(|err| {
tracing::error!(?err, "Failed to encrypt authn token");
AuthError::FailedToExtract
})?,
Ok(AuthToken::ApiKey(
RawApiKey::new(token)
.encrypt(&*ctx.secrets.encryptor)
.await
.map_err(|err| {
tracing::error!(?err, "Failed to encrypt authn token");
AuthError::FailedToExtract
})?,
))
}
}
Expand Down
7 changes: 5 additions & 2 deletions rfd-api/src/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ use crate::{
key::{key_to_encryptor, KeyEncryptor},
AuthError, AuthToken,
},
config::{JwtConfig, PermissionsConfig, AsymmetricKey},
config::{AsymmetricKey, JwtConfig, PermissionsConfig},
email_validator::EmailValidator,
endpoints::login::{
oauth::{OAuthProvider, OAuthProviderError, OAuthProviderFn, OAuthProviderName},
Expand Down Expand Up @@ -189,7 +189,7 @@ impl ApiContext {
permissions: PermissionsContext {
default: permissions.default.into(),
},

jwt: JwtContext {
default_expiration: jwt.default_expiration,
max_expiration: jwt.max_expiration,
Expand Down Expand Up @@ -393,6 +393,9 @@ impl ApiContext {

tracing::info!("Check for existing users matching the requested external id");

// TODO: Handle user merging. When a user signs in with a verified email that we have
// already seen how do we handle merges?

let api_user_providers = self
.list_api_user_provider(filter, &ListPagination::latest())
.await?;
Expand Down
6 changes: 4 additions & 2 deletions rfd-api/src/endpoints/login/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -151,12 +151,14 @@ pub struct UserInfo {

#[derive(Debug, Error)]
pub enum UserInfoError {
#[error("Failed to deserialize user info {0}")]
Deserialize(#[from] serde_json::Error),
#[error("Failed to create user info request {0}")]
Http(#[from] http::Error),
#[error("Failed to send user info request {0}")]
Hyper(#[from] hyper::Error),
#[error("Failed to deserialize user info {0}")]
Deserialize(#[from] serde_json::Error),
#[error("User information is missing")]
MissingUserInfoData(String),
}

#[async_trait]
Expand Down
79 changes: 37 additions & 42 deletions rfd-api/src/endpoints/login/oauth/code.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,11 @@ use uuid::Uuid;

use super::{OAuthProviderNameParam, UserInfoProvider};
use crate::{
authn::key::RawApiKey,
context::ApiContext,
endpoints::login::LoginError,
error::ApiError,
util::response::{bad_request, client_error, internal_error, to_internal_error}, authn::key::RawApiKey,
util::response::{bad_request, client_error, internal_error, to_internal_error},
};

#[derive(Debug, Deserialize, JsonSchema, Serialize)]
Expand Down Expand Up @@ -144,31 +145,29 @@ pub async fn authz_code_callback(

let original_attempt = match query.state {
Some(state) => {

// Attempt to extract the request id and csrf token from the state parameter. These
// must both be present
if let Some((id, csrf)) = state
.split_once(":")
.and_then(|(id, csrf)| id.parse::<Uuid>().ok().map(|id| (id, csrf))) {

// Look up the login attempt referenced in the state and verify that has the
// csrf value still matches
ctx
.get_login_attempt(&id)
.await
.map_err(to_internal_error)?
.and_then(|attempt| {
if attempt.state.as_ref().map(|s| s == csrf).unwrap_or(false) {
Some(attempt)
} else {
None
}
})
} else {
None
}
},
None => None
.and_then(|(id, csrf)| id.parse::<Uuid>().ok().map(|id| (id, csrf)))
{
// Look up the login attempt referenced in the state and verify that has the
// csrf value still matches
ctx.get_login_attempt(&id)
.await
.map_err(to_internal_error)?
.and_then(|attempt| {
if attempt.state.as_ref().map(|s| s == csrf).unwrap_or(false) {
Some(attempt)
} else {
None
}
})
} else {
None
}
}
None => None,
};

// If an attempt could not be found than the server needs to fail with an internal server error.
Expand All @@ -177,36 +176,33 @@ pub async fn authz_code_callback(
// are fully dependent on the state parameter. Instead a very short lived cookie should be used
// to track the attempt that the client is making so that we can restore the attempt without
// use of the state parameter
let mut attempt = original_attempt
.ok_or_else(|| internal_error("Failed to load matching login attempt"))?;
let mut attempt =
original_attempt.ok_or_else(|| internal_error("Failed to load matching login attempt"))?;

attempt = match (query.code, query.error) {
(Some(code), None) => {

// Store the authorization code returned by the underlying OAuth provider and transition the
// attempt to the awaiting state
ctx
.set_login_provider_authz_code(attempt, code.to_string())
ctx.set_login_provider_authz_code(attempt, code.to_string())
.await
.map_err(to_internal_error)?
}
(code, error) => {

// Store the provider return error for future debugging, but if an error has been
// returned or there is a missing code, then we can not report a successful process
attempt.provider_authz_code = code;

// TODO: Specialize the returned error
ctx.fail_login_attempt(attempt, Some("server_error"), error.as_deref()).await.map_err(to_internal_error)?
ctx.fail_login_attempt(attempt, Some("server_error"), error.as_deref())
.await
.map_err(to_internal_error)?
}
};

// Redirect back to the original authenticator
http_response_temporary_redirect(attempt.callback_url())
}



#[derive(Debug, Deserialize, JsonSchema, Serialize)]
pub struct OAuthAuthzCodeExchangeQuery {
pub client_id: Uuid,
Expand Down Expand Up @@ -248,7 +244,10 @@ pub async fn authz_code_exchange(
return Err(bad_request("Invalid grant type"));
}

let client_secret = RawApiKey::new(query.client_secret.clone()).encrypt(&*ctx.secrets.encryptor).await.map_err(to_internal_error)?;
let client_secret = RawApiKey::new(query.client_secret.clone())
.encrypt(&*ctx.secrets.encryptor)
.await
.map_err(to_internal_error)?;

ctx.get_oauth_client(&query.client_id)
.await
Expand All @@ -259,10 +258,7 @@ pub async fn authz_code_exchange(
if client.is_secret_valid(&client_secret.encrypted) {
Ok(client)
} else {
Err(client_error(
StatusCode::UNAUTHORIZED,
"Invalid secret",
))
Err(client_error(StatusCode::UNAUTHORIZED, "Invalid secret"))
}
} else {
Err(client_error(
Expand Down Expand Up @@ -296,12 +292,11 @@ pub async fn authz_code_exchange(

// Exchange the stored authorization code with the remote provider for a remote access token
let client = provider.as_client().map_err(to_internal_error)?;
let mut request = client
.exchange_code(AuthorizationCode::new(
attempt.provider_authz_code.ok_or_else(|| {
internal_error("Expected authorization code to exist due to attempt state")
})?,
));
let mut request = client.exchange_code(AuthorizationCode::new(
attempt.provider_authz_code.ok_or_else(|| {
internal_error("Expected authorization code to exist due to attempt state")
})?,
));

if let Some(pkce_verifier) = attempt.provider_pkce_verifier {
request = request.set_pkce_verifier(PkceCodeVerifier::new(pkce_verifier))
Expand Down
117 changes: 117 additions & 0 deletions rfd-api/src/endpoints/login/oauth/github.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
use std::fmt;

use hyper::body::Bytes;
use serde::{Deserialize, Serialize};

use crate::endpoints::login::{ExternalUserId, UserInfo, UserInfoError};

use super::{ExtractUserInfo, OAuthProvider, OAuthProviderName};

pub struct GitHubOAuthProvider {
public: GitHubPublicProvider,
private: Option<GitHubPrivateProvider>,
}

impl fmt::Debug for GitHubOAuthProvider {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("GitHubOAuthProvider").finish()
}
}

#[derive(Debug, Deserialize, Serialize)]
pub struct GitHubPublicProvider {
client_id: String,
}

pub struct GitHubPrivateProvider {
client_secret: String,
}

impl GitHubOAuthProvider {
pub fn new(client_id: String, client_secret: String) -> Self {
Self {
public: GitHubPublicProvider { client_id },
private: Some(GitHubPrivateProvider { client_secret }),
}
}
}

#[derive(Debug, Deserialize)]
struct GitHubUser {
id: String,
}

#[derive(Debug, Deserialize)]
struct GitHubUserEmails {
email: String,
verified: bool,
// TODO: Add ability to mask non-visible emails?
_visibility: Option<String>,
}

impl ExtractUserInfo for GitHubOAuthProvider {
// There should always be as many entries in the data list as there are endpoints. This should
// be changed in the future to be a static check
fn extract_user_info(&self, data: &[Bytes]) -> Result<UserInfo, UserInfoError> {
let user: GitHubUser = serde_json::from_slice(&data[1])?;

let remote_emails: Vec<GitHubUserEmails> = serde_json::from_slice(&data[1])?;
let verified_emails = remote_emails
.into_iter()
.filter(|email| email.verified)
.map(|e| e.email)
.collect::<Vec<_>>();

Ok(UserInfo {
external_id: ExternalUserId::GitHub(user.id),
verified_emails,
})
}
}

impl OAuthProvider for GitHubOAuthProvider {
fn name(&self) -> OAuthProviderName {
OAuthProviderName::GitHub
}

fn scopes(&self) -> Vec<&str> {
vec!["user:email"]
}

fn client_id(&self) -> &str {
&self.public.client_id
}

fn client_secret(&self) -> Option<&str> {
self.private
.as_ref()
.map(|private| private.client_secret.as_str())
}

fn user_info_endpoints(&self) -> Vec<&str> {
vec![
"https://api.github.com/user",
"https://api.github.com/user/emails",
]
}

fn device_code_endpoint(&self) -> &str {
"https://github.com/login/device/code"
}

fn auth_url_endpoint(&self) -> &str {
"https://github.com/login/oauth/authorize"
}

fn token_exchange_content_type(&self) -> &str {
"application/x-www-form-urlencoded"
}

fn token_exchange_endpoint(&self) -> &str {
"https://github.com/login/oauth/access_token"
}

fn supports_pkce(&self) -> bool {
true
}
}
16 changes: 8 additions & 8 deletions rfd-api/src/endpoints/login/oauth/google.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use std::fmt;

use hyper::body::Bytes;
use serde::{Deserialize, Serialize};

use crate::endpoints::login::{ExternalUserId, UserInfo, UserInfoError};
Expand Down Expand Up @@ -43,8 +44,10 @@ struct GoogleUserInfo {
}

impl ExtractUserInfo for GoogleOAuthProvider {
fn extract_user_info(&self, data: &[u8]) -> Result<UserInfo, UserInfoError> {
let remote_info: GoogleUserInfo = serde_json::from_slice(data)?;
// There should always be as many entries in the data list as there are endpoints. This should
// be changed in the future to be a static check
fn extract_user_info(&self, data: &[Bytes]) -> Result<UserInfo, UserInfoError> {
let remote_info: GoogleUserInfo = serde_json::from_slice(&data[0])?;
let verified_emails = if remote_info.email_verified {
vec![remote_info.email]
} else {
Expand All @@ -64,10 +67,7 @@ impl OAuthProvider for GoogleOAuthProvider {
}

fn scopes(&self) -> Vec<&str> {
vec![
"https://www.googleapis.com/auth/userinfo.email",
"openid"
]
vec!["https://www.googleapis.com/auth/userinfo.email", "openid"]
}

fn client_id(&self) -> &str {
Expand All @@ -80,8 +80,8 @@ impl OAuthProvider for GoogleOAuthProvider {
.map(|private| private.client_secret.as_str())
}

fn user_info_endpoint(&self) -> &str {
"https://openidconnect.googleapis.com/v1/userinfo"
fn user_info_endpoints(&self) -> Vec<&str> {
vec!["https://openidconnect.googleapis.com/v1/userinfo"]
}

fn device_code_endpoint(&self) -> &str {
Expand Down
Loading

0 comments on commit d73567d

Please sign in to comment.