diff --git a/src/api/client.rs b/src/api/client.rs index 59d0a44..21d7d1d 100644 --- a/src/api/client.rs +++ b/src/api/client.rs @@ -44,7 +44,7 @@ impl ApiClient { Ok(format!("Bearer {}", token)) } None => { - if let Some(token) = auth.borrow().first_oauth2_token() { + if let Some(token) = auth.borrow_mut().get_token_store().get_first_oauth2_token() { match token { Token::OAuth2(token) => Ok(format!("Bearer {}", token)), _ => Err(Error::AuthError(AuthError::WrongTokenFoundInStore)), @@ -92,7 +92,7 @@ impl ApiClient { if let Some(username) = username { auth_ref.get_token_store().get_oauth2_token(username) } else { - auth_ref.first_oauth2_token() + auth_ref.get_token_store().get_first_oauth2_token() } }; if let Some(token) = token { @@ -225,9 +225,8 @@ mod tests { } fn mock_auth() -> Auth { - let config = Config::from_env().unwrap(); + let config = Config::from_env(); let auth = Auth::new(config) - .unwrap() .with_token_store(TokenStore::from_file_path(".xurl_test".into())); auth } @@ -281,7 +280,7 @@ mod tests { .create_async() .await; - let config = Config::from_env().unwrap(); + let config = Config::from_env(); let client = ApiClient::new(config) .with_url(url) .with_auth(setup_tests_with_mock_oauth2_token()); @@ -307,7 +306,7 @@ mod tests { .create_async() .await; - let config = Config::from_env().unwrap(); + let config = Config::from_env(); let client = ApiClient::new(config) .with_url(url) .with_auth(setup_tests_with_mock_oauth1_token()); @@ -332,7 +331,7 @@ mod tests { .create_async() .await; - let config = Config::from_env().unwrap(); + let config = Config::from_env(); let client = ApiClient::new(config) .with_url(url) .with_auth(setup_tests_with_mock_app_auth()); @@ -357,7 +356,7 @@ mod tests { .create_async() .await; - let config = Config::from_env().unwrap(); + let config = Config::from_env(); let client = ApiClient::new(config.clone()) .with_url(url) .with_auth(setup_tests_with_mock_oauth2_token()); diff --git a/src/auth/mod.rs b/src/auth/mod.rs index 505854c..d85dc3c 100644 --- a/src/auth/mod.rs +++ b/src/auth/mod.rs @@ -24,6 +24,8 @@ use std::time::{SystemTime, UNIX_EPOCH}; #[derive(Debug, thiserror::Error)] pub enum AuthError { + #[error("Missing environment variable: {0}")] + MissingEnvVar(&'static str), #[error("Invalid URL: {0}")] InvalidUrl(String), #[error("Invalid code: {0}")] @@ -47,32 +49,26 @@ pub enum AuthError { } pub struct Auth { - client: BasicClient, token_store: TokenStore, info_url: String, + client_id: String, + client_secret: String, + auth_url: String, + token_url: String, + redirect_uri: String, } impl Auth { - pub fn new(config: Config) -> Result { - let client = BasicClient::new( - ClientId::new(config.client_id), - Some(ClientSecret::new(config.client_secret)), - AuthUrl::new(config.auth_url).map_err(|e| AuthError::InvalidUrl(e.to_string()))?, - Some( - TokenUrl::new(config.token_url) - .map_err(|e| AuthError::InvalidUrl(e.to_string()))?, - ), - ) - .set_redirect_uri( - RedirectUrl::new(config.redirect_uri) - .map_err(|e| AuthError::InvalidUrl(e.to_string()))?, - ); - - Ok(Self { - client, + pub fn new(config: Config) -> Self { + Self { token_store: TokenStore::new(), info_url: config.info_url, - }) + client_id: config.client_id, + client_secret: config.client_secret, + auth_url: config.auth_url, + token_url: config.token_url, + redirect_uri: config.redirect_uri, + } } #[allow(dead_code)] @@ -151,10 +147,27 @@ impl Auth { } } + if self.client_id.is_empty() || self.client_secret.is_empty() { + return Err(AuthError::MissingEnvVar("CLIENT_ID or CLIENT_SECRET")); + } + + let client = BasicClient::new( + ClientId::new(self.client_id.clone()), + Some(ClientSecret::new(self.client_secret.clone())), + AuthUrl::new(self.auth_url.clone()).map_err(|e| AuthError::InvalidUrl(e.to_string()))?, + Some( + TokenUrl::new(self.token_url.clone()) + .map_err(|e| AuthError::InvalidUrl(e.to_string()))?, + ), + ) + .set_redirect_uri( + RedirectUrl::new(self.redirect_uri.clone()) + .map_err(|e| AuthError::InvalidUrl(e.to_string()))?, + ); + let (code_challenge, code_verifier) = PkceCodeChallenge::new_random_sha256(); - let (auth_url, _csrf_token) = self - .client + let (auth_url, _csrf_token) = client .authorize_url(CsrfToken::new_random) .add_scope(Scope::new("block.read".to_string())) .add_scope(Scope::new("bookmark.read".to_string())) @@ -189,8 +202,7 @@ impl Auth { .await .map_err(|e| AuthError::InvalidCode(e))?; - let token = self - .client + let token = client .exchange_code(AuthorizationCode::new(code)) .set_pkce_verifier(code_verifier) .request_async(async_http_client) @@ -235,11 +247,6 @@ impl Auth { }) } - pub fn first_oauth2_token(&self) -> Option { - self.token_store.get_first_oauth2_token() - } - - #[allow(dead_code)] pub fn get_token_store(&mut self) -> &mut TokenStore { &mut self.token_store } diff --git a/src/config.rs b/src/config.rs index 4038152..d34ef0e 100644 --- a/src/config.rs +++ b/src/config.rs @@ -1,4 +1,3 @@ -use crate::error::Error; use std::env; #[derive(Clone)] @@ -17,10 +16,9 @@ pub struct Config { } impl Config { - pub fn from_env() -> Result { - let client_id = env::var("CLIENT_ID").map_err(|_| Error::MissingEnvVar("CLIENT_ID"))?; - let client_secret = - env::var("CLIENT_SECRET").map_err(|_| Error::MissingEnvVar("CLIENT_SECRET"))?; + pub fn from_env() -> Self { + let client_id = env::var("CLIENT_ID").unwrap_or_default(); + let client_secret = env::var("CLIENT_SECRET").unwrap_or_default(); let redirect_uri = env::var("REDIRECT_URI") .unwrap_or_else(|_| "http://localhost:8080/callback".to_string()); let auth_url = @@ -31,7 +29,7 @@ impl Config { env::var("API_BASE_URL").unwrap_or_else(|_| "https://api.x.com".to_string()); let info_url = env::var("INFO_URL").unwrap_or_else(|_| format!("{}/2/users/me", api_base_url)); - Ok(Self { + Self { client_id, client_secret, redirect_uri, @@ -39,6 +37,6 @@ impl Config { token_url, api_base_url, info_url, - }) + } } } diff --git a/src/error.rs b/src/error.rs index 0b207e2..961c74c 100644 --- a/src/error.rs +++ b/src/error.rs @@ -3,9 +3,6 @@ use thiserror::Error; #[derive(Error, Debug)] pub enum Error { - #[error("Missing environment variable: {0}")] - MissingEnvVar(&'static str), - #[error("HTTP error: {0}")] HttpError(#[from] reqwest::Error), diff --git a/src/main.rs b/src/main.rs index d131682..9670686 100644 --- a/src/main.rs +++ b/src/main.rs @@ -15,8 +15,8 @@ use error::Error; async fn main() -> Result<(), Error> { let cli = Cli::parse(); - let config = Config::from_env()?; - let mut auth = Auth::new(config.clone())?; + let config = Config::from_env(); + let mut auth = Auth::new(config.clone()); // Handle auth subcommands if let Some(Commands::Auth { command }) = cli.command {