From bbff701e36284b3dd9732e6b6f93d559feff35a8 Mon Sep 17 00:00:00 2001 From: Santiago Medina Rolong Date: Sat, 21 Dec 2024 00:54:43 -0800 Subject: [PATCH] Add OAuth2 token validation and refresh logic in ApiClient and Auth modules. --- .vscode/launch.json | 16 +++ src/api/client.rs | 86 ++++++++++----- src/auth/mod.rs | 232 ++++++++++++++++++++++++++++++---------- src/auth/token_store.rs | 23 +++- src/main.rs | 34 +++--- 5 files changed, 289 insertions(+), 102 deletions(-) create mode 100644 .vscode/launch.json diff --git a/.vscode/launch.json b/.vscode/launch.json new file mode 100644 index 0000000..10efcb2 --- /dev/null +++ b/.vscode/launch.json @@ -0,0 +1,16 @@ +{ + // Use IntelliSense to learn about possible attributes. + // Hover to view descriptions of existing attributes. + // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 + "version": "0.2.0", + "configurations": [ + { + "type": "lldb", + "request": "launch", + "name": "Debug", + "program": "${workspaceFolder}/", + "args": [], + "cwd": "${workspaceFolder}" + } + ] +} \ No newline at end of file diff --git a/src/api/client.rs b/src/api/client.rs index 21d7d1d..9dbc96f 100644 --- a/src/api/client.rs +++ b/src/api/client.rs @@ -6,7 +6,7 @@ use reqwest::RequestBuilder; use reqwest::{Client, Method}; use serde_json::Value; use std::cell::RefCell; - +use std::time::{SystemTime, UNIX_EPOCH}; pub struct ApiClient { url: String, client: Client, @@ -33,30 +33,57 @@ impl ApiClient { self } + /// Validate the OAuth2 token and refresh it if it is expired + async fn validate_and_refresh_oauth2_token( + &self, + auth: &RefCell, + token: Token, + username: Option<&str>, + ) -> Result { + match token { + Token::OAuth2(token) => { + let current_time = SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap() + .as_secs(); + + if current_time > token.expiration_time { + let new_token = auth.borrow_mut().oauth2_refresh_token(username).await?; + Ok(format!("Bearer {}", new_token)) + } else { + Ok(format!("Bearer {}", token.access_token)) + } + } + _ => Err(Error::AuthError(AuthError::WrongTokenFoundInStore)), + } + } + + /// Get the OAuth2 token from the token store, validate it and refresh it if it is expired async fn get_oauth2_token( &self, auth: &RefCell, username: Option<&str>, ) -> Result { - match username { - Some(username) => { - let token = auth.borrow_mut().oauth2(Some(username)).await?; - Ok(format!("Bearer {}", token)) + let token = { + let mut auth_ref = auth.borrow_mut(); + match username { + Some(username) => auth_ref.get_token_store().get_oauth2_token(username), + None => auth_ref.get_token_store().get_first_oauth2_token(), + } + }; + match token { + Some(token) => { + self.validate_and_refresh_oauth2_token(auth, token, username) + .await } None => { - if let Some(token) = auth.borrow_mut().get_token_store().get_first_oauth2_token() { - match token { - Token::OAuth2(token) => Ok(format!("Bearer {}", token)), - _ => Err(Error::AuthError(AuthError::WrongTokenFoundInStore)), - } - } else { - let token = auth.borrow_mut().oauth2(None).await?; - Ok(format!("Bearer {}", token)) - } + let token = auth.borrow_mut().oauth2(username).await?; + Ok(format!("Bearer {}", token)) } } } + /// Get the auth header for the request async fn get_auth_header( &self, method: &str, @@ -69,7 +96,7 @@ impl ApiClient { None => return Ok("".to_string()), }; - match auth_type.as_deref() { + match auth_type { Some("app") => { if let Some(token) = auth.borrow().bearer_token() { Ok(format!("Bearer {}", token)) @@ -90,16 +117,16 @@ impl ApiClient { let token = { let mut auth_ref = auth.borrow_mut(); if let Some(username) = username { + // Username passed, we need to get the token for the specific username auth_ref.get_token_store().get_oauth2_token(username) } else { + // No username passed, we need to get the first oauth2 token auth_ref.get_token_store().get_first_oauth2_token() } }; if let Some(token) = token { - match token { - Token::OAuth2(token) => Ok(format!("Bearer {}", token)), - _ => Err(Error::AuthError(AuthError::WrongTokenFoundInStore)), - } + self.validate_and_refresh_oauth2_token(auth, token, username) + .await } else { let oauth1_result = { let auth_ref = auth.borrow(); @@ -121,7 +148,7 @@ impl ApiClient { } } -pub async fn build_request( + pub async fn build_request( &self, method: &str, endpoint: &str, @@ -188,8 +215,8 @@ pub async fn build_request( let response = request_builder.send().await?; if verbose { - println!("Request: {:#?}", req); - println!("Response: {:#?}", response) + println!("{:#?}", req); + println!("{:#?}", response) } let status = response.status(); @@ -201,7 +228,7 @@ pub async fn build_request( } else { Ok(res) } - }, + } Err(_) => { let status = status.to_string(); Err(Error::ApiError(serde_json::json!({ @@ -226,15 +253,22 @@ mod tests { fn mock_auth() -> Auth { let config = Config::from_env(); - let auth = Auth::new(config) - .with_token_store(TokenStore::from_file_path(".xurl_test".into())); + let auth = + Auth::new(config).with_token_store(TokenStore::from_file_path(".xurl_test".into())); auth } fn setup_tests_with_mock_oauth2_token() -> Auth { let mut auth = mock_auth(); let token_store = auth.get_token_store(); - token_store.save_oauth2_token("test", "fake_token").unwrap(); + let current_time = SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap() + .as_secs() + + 7200; + token_store + .save_oauth2_token("test", "fake_token", "fake_refresh_token", current_time) + .unwrap(); auth } diff --git a/src/auth/mod.rs b/src/auth/mod.rs index d85dc3c..bc125a8 100644 --- a/src/auth/mod.rs +++ b/src/auth/mod.rs @@ -9,9 +9,11 @@ use crate::config::Config; use oauth2::basic::BasicClient; use oauth2::reqwest::async_http_client; +use oauth2::RefreshToken; use oauth2::{ - AuthUrl, AuthorizationCode, ClientId, ClientSecret, CsrfToken, PkceCodeChallenge, RedirectUrl, - Scope, TokenResponse, TokenUrl, + basic::BasicTokenType, AuthUrl, AuthorizationCode, ClientId, ClientSecret, CsrfToken, + EmptyExtraTokenFields, PkceCodeChallenge, RedirectUrl, Scope, StandardTokenResponse, + TokenResponse, TokenUrl, }; use base64::{engine::general_purpose::STANDARD, Engine}; @@ -20,7 +22,7 @@ use percent_encoding::{utf8_percent_encode, NON_ALPHANUMERIC}; use rand::Rng; use sha1::Sha1; use std::collections::BTreeMap; -use std::time::{SystemTime, UNIX_EPOCH}; +use std::time::{Duration, SystemTime, UNIX_EPOCH}; #[derive(Debug, thiserror::Error)] pub enum AuthError { @@ -136,7 +138,17 @@ impl Auth { if let Some(username) = username { if let Some(token) = self.token_store.get_oauth2_token(username) { match token { - Token::OAuth2(token) => return Ok(token), + Token::OAuth2(token) => { + if SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap() + .as_secs() + > token.expiration_time + { + return self.oauth2_refresh_token(Some(username)).await; + } + return Ok(token.access_token); + } _ => return Err(AuthError::WrongTokenFoundInStore), } } else { @@ -151,47 +163,13 @@ impl Auth { return Err(AuthError::MissingEnvVar("CLIENT_ID or CLIENT_SECRET")); } - let client = BasicClient::new( - ClientId::new(self.client_id.clone()), - Some(ClientSecret::new(self.client_secret.clone())), - AuthUrl::new(self.auth_url.clone()).map_err(|e| AuthError::InvalidUrl(e.to_string()))?, - Some( - TokenUrl::new(self.token_url.clone()) - .map_err(|e| AuthError::InvalidUrl(e.to_string()))?, - ), - ) - .set_redirect_uri( - RedirectUrl::new(self.redirect_uri.clone()) - .map_err(|e| AuthError::InvalidUrl(e.to_string()))?, - ); + let client = self.create_oauth2_client().await?; let (code_challenge, code_verifier) = PkceCodeChallenge::new_random_sha256(); let (auth_url, _csrf_token) = client .authorize_url(CsrfToken::new_random) - .add_scope(Scope::new("block.read".to_string())) - .add_scope(Scope::new("bookmark.read".to_string())) - .add_scope(Scope::new("dm.read".to_string())) - .add_scope(Scope::new("follows.read".to_string())) - .add_scope(Scope::new("like.read".to_string())) - .add_scope(Scope::new("list.read".to_string())) - .add_scope(Scope::new("mute.read".to_string())) - .add_scope(Scope::new("space.read".to_string())) - .add_scope(Scope::new("tweet.read".to_string())) - .add_scope(Scope::new("timeline.read".to_string())) - .add_scope(Scope::new("users.read".to_string())) - .add_scope(Scope::new("block.write".to_string())) - .add_scope(Scope::new("bookmark.write".to_string())) - .add_scope(Scope::new("dm.write".to_string())) - .add_scope(Scope::new("follows.write".to_string())) - .add_scope(Scope::new("like.write".to_string())) - .add_scope(Scope::new("list.write".to_string())) - .add_scope(Scope::new("mute.write".to_string())) - .add_scope(Scope::new("tweet.write".to_string())) - .add_scope(Scope::new("tweet.moderate.write".to_string())) - .add_scope(Scope::new("timeline.write".to_string())) - .add_scope(Scope::new("offline.access".to_string())) - .add_scope(Scope::new("media.write".to_string())) + .add_scopes(OAuth2Scopes::all()) .set_pkce_challenge(code_challenge) .url(); @@ -215,26 +193,57 @@ impl Auth { _ => AuthError::InvalidToken(e.to_string()), })?; - let token = token.access_token().secret().to_string(); + let username = self + .fetch_username(&token.access_token().secret().to_string()) + .await?; + self.save_token_data(&username, &token)?; - let username = reqwest::Client::new() - .get(&self.info_url) - .header("Authorization", format!("Bearer {}", token)) - .send() - .await - .map_err(|e| AuthError::NetworkError(e.to_string()))? - .json::() - .await - .map_err(|e| AuthError::NetworkError(e.to_string()))?; + Ok(token.access_token().secret().to_string()) + } - let username = username["data"]["username"] - .as_str() - .ok_or_else(|| AuthError::NetworkError("Missing username field".to_string()))? - .to_string(); + pub async fn oauth2_refresh_token( + &mut self, + username: Option<&str>, + ) -> Result { + let refresh_token = if let Some(username) = username { + if let Some(token) = self.token_store.get_oauth2_token(username) { + match token { + Token::OAuth2(token) => token.refresh_token, + _ => return Err(AuthError::WrongTokenFoundInStore), + } + } else { + return Err(AuthError::TokenNotFound(format!( + "No cached OAuth2 token found for {}", + username + ))); + } + } else { + let token = self.token_store.get_first_oauth2_token(); + if let Some(token) = token { + match token { + Token::OAuth2(token) => token.refresh_token, + _ => return Err(AuthError::WrongTokenFoundInStore), + } + } else { + return Err(AuthError::TokenNotFound( + "No OAuth2 tokens found".to_string(), + )); + } + }; + let client = self.create_oauth2_client().await?; + + let token = client + .exchange_refresh_token(&RefreshToken::new(refresh_token)) + .request_async(async_http_client) + .await + .map_err(|e| AuthError::InvalidToken(e.to_string()))?; - self.token_store.save_oauth2_token(&username, &token)?; + let username = self + .fetch_username(&token.access_token().secret().to_string()) + .await?; + self.save_token_data(&username, &token)?; - Ok(token) + Ok(token.access_token().secret().to_string()) } pub fn bearer_token(&self) -> Option { @@ -250,6 +259,117 @@ impl Auth { pub fn get_token_store(&mut self) -> &mut TokenStore { &mut self.token_store } + + async fn create_oauth2_client(&self) -> Result { + let client = BasicClient::new( + ClientId::new(self.client_id.clone()), + Some(ClientSecret::new(self.client_secret.clone())), + AuthUrl::new(self.auth_url.clone()) + .map_err(|e| AuthError::InvalidUrl(e.to_string()))?, + Some( + TokenUrl::new(self.token_url.clone()) + .map_err(|e| AuthError::InvalidUrl(e.to_string()))?, + ), + ) + .set_redirect_uri( + RedirectUrl::new(self.redirect_uri.clone()) + .map_err(|e| AuthError::InvalidUrl(e.to_string()))?, + ); + + Ok(client) + } + + async fn fetch_username(&self, access_token: &str) -> Result { + let response = reqwest::Client::new() + .get(&self.info_url) + .header("Authorization", format!("Bearer {}", access_token)) + .send() + .await + .map_err(|e| AuthError::NetworkError(e.to_string()))? + .json::() + .await + .map_err(|e| AuthError::NetworkError(e.to_string()))?; + + response["data"]["username"] + .as_str() + .ok_or_else(|| AuthError::NetworkError("Missing username field".to_string())) + .map(String::from) + } + + fn save_token_data( + &mut self, + username: &str, + token: &StandardTokenResponse, + ) -> Result<(), TokenStoreError> { + let access_token = token.access_token().secret().to_string(); + let refresh_token = token + .refresh_token() + .ok_or(TokenStoreError::RefreshTokenNotFound)? + .secret() + .to_string(); + + let expiration_time = token + .expires_in() + .unwrap_or(Duration::from_secs(7200)) + .as_secs() + + SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap() + .as_secs(); + + self.token_store + .save_oauth2_token(username, &access_token, &refresh_token, expiration_time) + } +} + +struct OAuth2Scopes { + read_scopes: Vec<&'static str>, + write_scopes: Vec<&'static str>, + other_scopes: Vec<&'static str>, +} + +impl OAuth2Scopes { + fn all() -> Vec { + let scopes = Self { + read_scopes: vec![ + "block.read", + "bookmark.read", + "dm.read", + "follows.read", + "like.read", + "list.read", + "mute.read", + "space.read", + "tweet.read", + "timeline.read", + "users.read", + ], + write_scopes: vec![ + "block.write", + "bookmark.write", + "dm.write", + "follows.write", + "like.write", + "list.write", + "mute.write", + "tweet.write", + "tweet.moderate.write", + "timeline.write", + "media.write", + ], + other_scopes: vec!["offline.access"], + }; + scopes.to_oauth_scopes() + } + + fn to_oauth_scopes(self) -> Vec { + self.read_scopes + .into_iter() + .chain(self.write_scopes) + .chain(self.other_scopes) + .map(|s| Scope::new(s.to_string())) + .collect() + } } // OAuth 1.0 helper functions diff --git a/src/auth/token_store.rs b/src/auth/token_store.rs index 7f7592f..dc81cf7 100644 --- a/src/auth/token_store.rs +++ b/src/auth/token_store.rs @@ -11,12 +11,19 @@ pub struct OAuth1Token { pub(crate) consumer_secret: String, } +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct OAuth2Token { + pub(crate) access_token: String, + pub(crate) refresh_token: String, + pub(crate) expiration_time: u64, +} + #[derive(Debug, Serialize, Deserialize, Clone)] pub enum Token { #[serde(rename = "bearer")] Bearer(String), // Bearer token #[serde(rename = "oauth2")] - OAuth2(String), // access_token + OAuth2(OAuth2Token), // access_token #[serde(rename = "oauth1")] OAuth1(OAuth1Token), } @@ -29,6 +36,8 @@ pub enum TokenStoreError { JSONDeserializationError, #[error("IO error")] IOError, + #[error("Refresh token not found")] + RefreshTokenNotFound, } #[derive(Debug, Serialize, Deserialize)] @@ -78,9 +87,17 @@ impl TokenStore { &mut self, username: &str, token: &str, + refresh_token: &str, + expiration_time: u64, ) -> Result<(), TokenStoreError> { - self.oauth2_tokens - .insert(username.to_string(), Token::OAuth2(token.to_string())); + self.oauth2_tokens.insert( + username.to_string(), + Token::OAuth2(OAuth2Token { + access_token: token.to_string(), + refresh_token: refresh_token.to_string(), + expiration_time, + }), + ); self.save_to_file() } diff --git a/src/main.rs b/src/main.rs index 9670686..24abcc0 100644 --- a/src/main.rs +++ b/src/main.rs @@ -110,25 +110,25 @@ async fn main() -> Result<(), Error> { cli.username.as_deref(), cli.verbose, ) - .await { - Ok(res) => res, - Err(e) => match e { - Error::ApiError(e) => { - println!("{}", serde_json::to_string_pretty(&e)?); - std::process::exit(1) - }, - Error::HttpError(e) => { - println!("{}", e); - std::process::exit(1) - }, - _ => { - println!("{}", e); - std::process::exit(1) - } + .await + { + Ok(res) => res, + Err(e) => match e { + Error::ApiError(e) => { + println!("{}", serde_json::to_string_pretty(&e)?); + std::process::exit(1) + } + Error::HttpError(e) => { + println!("{}", e); + std::process::exit(1) + } + _ => { + println!("{}", e); + std::process::exit(1) } - }; + }, + }; - // Pretty print the response println!("{}", serde_json::to_string_pretty(&response)?);