Skip to content

Commit

Permalink
Merge pull request #4 from xdevplatform/santiagomed/auth-fixes
Browse files Browse the repository at this point in the history
oauth2 client id enforcement fix
  • Loading branch information
santiagomed authored Dec 18, 2024
2 parents d85276e + 906512e commit b5c8291
Show file tree
Hide file tree
Showing 5 changed files with 49 additions and 48 deletions.
15 changes: 7 additions & 8 deletions src/api/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ impl ApiClient {
Ok(format!("Bearer {}", token))
}
None => {
if let Some(token) = auth.borrow().first_oauth2_token() {
if let Some(token) = auth.borrow_mut().get_token_store().get_first_oauth2_token() {
match token {
Token::OAuth2(token) => Ok(format!("Bearer {}", token)),
_ => Err(Error::AuthError(AuthError::WrongTokenFoundInStore)),
Expand Down Expand Up @@ -92,7 +92,7 @@ impl ApiClient {
if let Some(username) = username {
auth_ref.get_token_store().get_oauth2_token(username)
} else {
auth_ref.first_oauth2_token()
auth_ref.get_token_store().get_first_oauth2_token()
}
};
if let Some(token) = token {
Expand Down Expand Up @@ -225,9 +225,8 @@ mod tests {
}

fn mock_auth() -> Auth {
let config = Config::from_env().unwrap();
let config = Config::from_env();
let auth = Auth::new(config)
.unwrap()
.with_token_store(TokenStore::from_file_path(".xurl_test".into()));
auth
}
Expand Down Expand Up @@ -281,7 +280,7 @@ mod tests {
.create_async()
.await;

let config = Config::from_env().unwrap();
let config = Config::from_env();
let client = ApiClient::new(config)
.with_url(url)
.with_auth(setup_tests_with_mock_oauth2_token());
Expand All @@ -307,7 +306,7 @@ mod tests {
.create_async()
.await;

let config = Config::from_env().unwrap();
let config = Config::from_env();
let client = ApiClient::new(config)
.with_url(url)
.with_auth(setup_tests_with_mock_oauth1_token());
Expand All @@ -332,7 +331,7 @@ mod tests {
.create_async()
.await;

let config = Config::from_env().unwrap();
let config = Config::from_env();
let client = ApiClient::new(config)
.with_url(url)
.with_auth(setup_tests_with_mock_app_auth());
Expand All @@ -357,7 +356,7 @@ mod tests {
.create_async()
.await;

let config = Config::from_env().unwrap();
let config = Config::from_env();
let client = ApiClient::new(config.clone())
.with_url(url)
.with_auth(setup_tests_with_mock_oauth2_token());
Expand Down
63 changes: 35 additions & 28 deletions src/auth/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ use std::time::{SystemTime, UNIX_EPOCH};

#[derive(Debug, thiserror::Error)]
pub enum AuthError {
#[error("Missing environment variable: {0}")]
MissingEnvVar(&'static str),
#[error("Invalid URL: {0}")]
InvalidUrl(String),
#[error("Invalid code: {0}")]
Expand All @@ -47,32 +49,26 @@ pub enum AuthError {
}

pub struct Auth {
client: BasicClient,
token_store: TokenStore,
info_url: String,
client_id: String,
client_secret: String,
auth_url: String,
token_url: String,
redirect_uri: String,
}

impl Auth {
pub fn new(config: Config) -> Result<Self, AuthError> {
let client = BasicClient::new(
ClientId::new(config.client_id),
Some(ClientSecret::new(config.client_secret)),
AuthUrl::new(config.auth_url).map_err(|e| AuthError::InvalidUrl(e.to_string()))?,
Some(
TokenUrl::new(config.token_url)
.map_err(|e| AuthError::InvalidUrl(e.to_string()))?,
),
)
.set_redirect_uri(
RedirectUrl::new(config.redirect_uri)
.map_err(|e| AuthError::InvalidUrl(e.to_string()))?,
);

Ok(Self {
client,
pub fn new(config: Config) -> Self {
Self {
token_store: TokenStore::new(),
info_url: config.info_url,
})
client_id: config.client_id,
client_secret: config.client_secret,
auth_url: config.auth_url,
token_url: config.token_url,
redirect_uri: config.redirect_uri,
}
}

#[allow(dead_code)]
Expand Down Expand Up @@ -151,10 +147,27 @@ impl Auth {
}
}

if self.client_id.is_empty() || self.client_secret.is_empty() {
return Err(AuthError::MissingEnvVar("CLIENT_ID or CLIENT_SECRET"));
}

let client = BasicClient::new(
ClientId::new(self.client_id.clone()),
Some(ClientSecret::new(self.client_secret.clone())),
AuthUrl::new(self.auth_url.clone()).map_err(|e| AuthError::InvalidUrl(e.to_string()))?,
Some(
TokenUrl::new(self.token_url.clone())
.map_err(|e| AuthError::InvalidUrl(e.to_string()))?,
),
)
.set_redirect_uri(
RedirectUrl::new(self.redirect_uri.clone())
.map_err(|e| AuthError::InvalidUrl(e.to_string()))?,
);

let (code_challenge, code_verifier) = PkceCodeChallenge::new_random_sha256();

let (auth_url, _csrf_token) = self
.client
let (auth_url, _csrf_token) = client
.authorize_url(CsrfToken::new_random)
.add_scope(Scope::new("block.read".to_string()))
.add_scope(Scope::new("bookmark.read".to_string()))
Expand Down Expand Up @@ -189,8 +202,7 @@ impl Auth {
.await
.map_err(|e| AuthError::InvalidCode(e))?;

let token = self
.client
let token = client
.exchange_code(AuthorizationCode::new(code))
.set_pkce_verifier(code_verifier)
.request_async(async_http_client)
Expand Down Expand Up @@ -235,11 +247,6 @@ impl Auth {
})
}

pub fn first_oauth2_token(&self) -> Option<Token> {
self.token_store.get_first_oauth2_token()
}

#[allow(dead_code)]
pub fn get_token_store(&mut self) -> &mut TokenStore {
&mut self.token_store
}
Expand Down
12 changes: 5 additions & 7 deletions src/config.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
use crate::error::Error;
use std::env;

#[derive(Clone)]
Expand All @@ -17,10 +16,9 @@ pub struct Config {
}

impl Config {
pub fn from_env() -> Result<Self, Error> {
let client_id = env::var("CLIENT_ID").map_err(|_| Error::MissingEnvVar("CLIENT_ID"))?;
let client_secret =
env::var("CLIENT_SECRET").map_err(|_| Error::MissingEnvVar("CLIENT_SECRET"))?;
pub fn from_env() -> Self {
let client_id = env::var("CLIENT_ID").unwrap_or_default();
let client_secret = env::var("CLIENT_SECRET").unwrap_or_default();
let redirect_uri = env::var("REDIRECT_URI")
.unwrap_or_else(|_| "http://localhost:8080/callback".to_string());
let auth_url =
Expand All @@ -31,14 +29,14 @@ impl Config {
env::var("API_BASE_URL").unwrap_or_else(|_| "https://api.x.com".to_string());
let info_url =
env::var("INFO_URL").unwrap_or_else(|_| format!("{}/2/users/me", api_base_url));
Ok(Self {
Self {
client_id,
client_secret,
redirect_uri,
auth_url,
token_url,
api_base_url,
info_url,
})
}
}
}
3 changes: 0 additions & 3 deletions src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,6 @@ use thiserror::Error;

#[derive(Error, Debug)]
pub enum Error {
#[error("Missing environment variable: {0}")]
MissingEnvVar(&'static str),

#[error("HTTP error: {0}")]
HttpError(#[from] reqwest::Error),

Expand Down
4 changes: 2 additions & 2 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ use error::Error;
async fn main() -> Result<(), Error> {
let cli = Cli::parse();

let config = Config::from_env()?;
let mut auth = Auth::new(config.clone())?;
let config = Config::from_env();
let mut auth = Auth::new(config.clone());

// Handle auth subcommands
if let Some(Commands::Auth { command }) = cli.command {
Expand Down

0 comments on commit b5c8291

Please sign in to comment.