Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

oauth2 client id enforcement fix #4

Merged
merged 2 commits into from
Dec 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 7 additions & 8 deletions src/api/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ impl ApiClient {
Ok(format!("Bearer {}", token))
}
None => {
if let Some(token) = auth.borrow().first_oauth2_token() {
if let Some(token) = auth.borrow_mut().get_token_store().get_first_oauth2_token() {
match token {
Token::OAuth2(token) => Ok(format!("Bearer {}", token)),
_ => Err(Error::AuthError(AuthError::WrongTokenFoundInStore)),
Expand Down Expand Up @@ -92,7 +92,7 @@ impl ApiClient {
if let Some(username) = username {
auth_ref.get_token_store().get_oauth2_token(username)
} else {
auth_ref.first_oauth2_token()
auth_ref.get_token_store().get_first_oauth2_token()
}
};
if let Some(token) = token {
Expand Down Expand Up @@ -225,9 +225,8 @@ mod tests {
}

fn mock_auth() -> Auth {
let config = Config::from_env().unwrap();
let config = Config::from_env();
let auth = Auth::new(config)
.unwrap()
.with_token_store(TokenStore::from_file_path(".xurl_test".into()));
auth
}
Expand Down Expand Up @@ -281,7 +280,7 @@ mod tests {
.create_async()
.await;

let config = Config::from_env().unwrap();
let config = Config::from_env();
let client = ApiClient::new(config)
.with_url(url)
.with_auth(setup_tests_with_mock_oauth2_token());
Expand All @@ -307,7 +306,7 @@ mod tests {
.create_async()
.await;

let config = Config::from_env().unwrap();
let config = Config::from_env();
let client = ApiClient::new(config)
.with_url(url)
.with_auth(setup_tests_with_mock_oauth1_token());
Expand All @@ -332,7 +331,7 @@ mod tests {
.create_async()
.await;

let config = Config::from_env().unwrap();
let config = Config::from_env();
let client = ApiClient::new(config)
.with_url(url)
.with_auth(setup_tests_with_mock_app_auth());
Expand All @@ -357,7 +356,7 @@ mod tests {
.create_async()
.await;

let config = Config::from_env().unwrap();
let config = Config::from_env();
let client = ApiClient::new(config.clone())
.with_url(url)
.with_auth(setup_tests_with_mock_oauth2_token());
Expand Down
63 changes: 35 additions & 28 deletions src/auth/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ use std::time::{SystemTime, UNIX_EPOCH};

#[derive(Debug, thiserror::Error)]
pub enum AuthError {
#[error("Missing environment variable: {0}")]
MissingEnvVar(&'static str),
#[error("Invalid URL: {0}")]
InvalidUrl(String),
#[error("Invalid code: {0}")]
Expand All @@ -47,32 +49,26 @@ pub enum AuthError {
}

pub struct Auth {
client: BasicClient,
token_store: TokenStore,
info_url: String,
client_id: String,
client_secret: String,
auth_url: String,
token_url: String,
redirect_uri: String,
}

impl Auth {
pub fn new(config: Config) -> Result<Self, AuthError> {
let client = BasicClient::new(
ClientId::new(config.client_id),
Some(ClientSecret::new(config.client_secret)),
AuthUrl::new(config.auth_url).map_err(|e| AuthError::InvalidUrl(e.to_string()))?,
Some(
TokenUrl::new(config.token_url)
.map_err(|e| AuthError::InvalidUrl(e.to_string()))?,
),
)
.set_redirect_uri(
RedirectUrl::new(config.redirect_uri)
.map_err(|e| AuthError::InvalidUrl(e.to_string()))?,
);

Ok(Self {
client,
pub fn new(config: Config) -> Self {
Self {
token_store: TokenStore::new(),
info_url: config.info_url,
})
client_id: config.client_id,
client_secret: config.client_secret,
auth_url: config.auth_url,
token_url: config.token_url,
redirect_uri: config.redirect_uri,
}
}

#[allow(dead_code)]
Expand Down Expand Up @@ -151,10 +147,27 @@ impl Auth {
}
}

if self.client_id.is_empty() || self.client_secret.is_empty() {
return Err(AuthError::MissingEnvVar("CLIENT_ID or CLIENT_SECRET"));
}

let client = BasicClient::new(
ClientId::new(self.client_id.clone()),
Some(ClientSecret::new(self.client_secret.clone())),
AuthUrl::new(self.auth_url.clone()).map_err(|e| AuthError::InvalidUrl(e.to_string()))?,
Some(
TokenUrl::new(self.token_url.clone())
.map_err(|e| AuthError::InvalidUrl(e.to_string()))?,
),
)
.set_redirect_uri(
RedirectUrl::new(self.redirect_uri.clone())
.map_err(|e| AuthError::InvalidUrl(e.to_string()))?,
);

let (code_challenge, code_verifier) = PkceCodeChallenge::new_random_sha256();

let (auth_url, _csrf_token) = self
.client
let (auth_url, _csrf_token) = client
.authorize_url(CsrfToken::new_random)
.add_scope(Scope::new("block.read".to_string()))
.add_scope(Scope::new("bookmark.read".to_string()))
Expand Down Expand Up @@ -189,8 +202,7 @@ impl Auth {
.await
.map_err(|e| AuthError::InvalidCode(e))?;

let token = self
.client
let token = client
.exchange_code(AuthorizationCode::new(code))
.set_pkce_verifier(code_verifier)
.request_async(async_http_client)
Expand Down Expand Up @@ -235,11 +247,6 @@ impl Auth {
})
}

pub fn first_oauth2_token(&self) -> Option<Token> {
self.token_store.get_first_oauth2_token()
}

#[allow(dead_code)]
pub fn get_token_store(&mut self) -> &mut TokenStore {
&mut self.token_store
}
Expand Down
12 changes: 5 additions & 7 deletions src/config.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
use crate::error::Error;
use std::env;

#[derive(Clone)]
Expand All @@ -17,10 +16,9 @@ pub struct Config {
}

impl Config {
pub fn from_env() -> Result<Self, Error> {
let client_id = env::var("CLIENT_ID").map_err(|_| Error::MissingEnvVar("CLIENT_ID"))?;
let client_secret =
env::var("CLIENT_SECRET").map_err(|_| Error::MissingEnvVar("CLIENT_SECRET"))?;
pub fn from_env() -> Self {
let client_id = env::var("CLIENT_ID").unwrap_or_default();
let client_secret = env::var("CLIENT_SECRET").unwrap_or_default();
let redirect_uri = env::var("REDIRECT_URI")
.unwrap_or_else(|_| "http://localhost:8080/callback".to_string());
let auth_url =
Expand All @@ -31,14 +29,14 @@ impl Config {
env::var("API_BASE_URL").unwrap_or_else(|_| "https://api.x.com".to_string());
let info_url =
env::var("INFO_URL").unwrap_or_else(|_| format!("{}/2/users/me", api_base_url));
Ok(Self {
Self {
client_id,
client_secret,
redirect_uri,
auth_url,
token_url,
api_base_url,
info_url,
})
}
}
}
3 changes: 0 additions & 3 deletions src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,6 @@ use thiserror::Error;

#[derive(Error, Debug)]
pub enum Error {
#[error("Missing environment variable: {0}")]
MissingEnvVar(&'static str),

#[error("HTTP error: {0}")]
HttpError(#[from] reqwest::Error),

Expand Down
4 changes: 2 additions & 2 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ use error::Error;
async fn main() -> Result<(), Error> {
let cli = Cli::parse();

let config = Config::from_env()?;
let mut auth = Auth::new(config.clone())?;
let config = Config::from_env();
let mut auth = Auth::new(config.clone());

// Handle auth subcommands
if let Some(Commands::Auth { command }) = cli.command {
Expand Down
Loading