Skip to content

Commit

Permalink
Add OAuth2 token validation and refresh logic in ApiClient and Auth m…
Browse files Browse the repository at this point in the history
…odules.
  • Loading branch information
santiagomed committed Dec 21, 2024
1 parent e1939e2 commit bbff701
Show file tree
Hide file tree
Showing 5 changed files with 289 additions and 102 deletions.
16 changes: 16 additions & 0 deletions .vscode/launch.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
{
// Use IntelliSense to learn about possible attributes.
// Hover to view descriptions of existing attributes.
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
"version": "0.2.0",
"configurations": [
{
"type": "lldb",
"request": "launch",
"name": "Debug",
"program": "${workspaceFolder}/<executable file>",
"args": [],
"cwd": "${workspaceFolder}"
}
]
}
86 changes: 60 additions & 26 deletions src/api/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use reqwest::RequestBuilder;
use reqwest::{Client, Method};
use serde_json::Value;
use std::cell::RefCell;

use std::time::{SystemTime, UNIX_EPOCH};
pub struct ApiClient {
url: String,
client: Client,
Expand All @@ -33,30 +33,57 @@ impl ApiClient {
self
}

/// Validate the OAuth2 token and refresh it if it is expired
async fn validate_and_refresh_oauth2_token(
&self,
auth: &RefCell<Auth>,
token: Token,
username: Option<&str>,
) -> Result<String, Error> {
match token {
Token::OAuth2(token) => {
let current_time = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap()
.as_secs();

if current_time > token.expiration_time {
let new_token = auth.borrow_mut().oauth2_refresh_token(username).await?;
Ok(format!("Bearer {}", new_token))
} else {
Ok(format!("Bearer {}", token.access_token))
}
}
_ => Err(Error::AuthError(AuthError::WrongTokenFoundInStore)),
}
}

/// Get the OAuth2 token from the token store, validate it and refresh it if it is expired
async fn get_oauth2_token(
&self,
auth: &RefCell<Auth>,
username: Option<&str>,
) -> Result<String, Error> {
match username {
Some(username) => {
let token = auth.borrow_mut().oauth2(Some(username)).await?;
Ok(format!("Bearer {}", token))
let token = {
let mut auth_ref = auth.borrow_mut();
match username {
Some(username) => auth_ref.get_token_store().get_oauth2_token(username),
None => auth_ref.get_token_store().get_first_oauth2_token(),
}
};
match token {
Some(token) => {
self.validate_and_refresh_oauth2_token(auth, token, username)
.await
}
None => {
if let Some(token) = auth.borrow_mut().get_token_store().get_first_oauth2_token() {
match token {
Token::OAuth2(token) => Ok(format!("Bearer {}", token)),
_ => Err(Error::AuthError(AuthError::WrongTokenFoundInStore)),
}
} else {
let token = auth.borrow_mut().oauth2(None).await?;
Ok(format!("Bearer {}", token))
}
let token = auth.borrow_mut().oauth2(username).await?;
Ok(format!("Bearer {}", token))
}
}
}

/// Get the auth header for the request
async fn get_auth_header(
&self,
method: &str,
Expand All @@ -69,7 +96,7 @@ impl ApiClient {
None => return Ok("".to_string()),
};

match auth_type.as_deref() {
match auth_type {
Some("app") => {
if let Some(token) = auth.borrow().bearer_token() {
Ok(format!("Bearer {}", token))
Expand All @@ -90,16 +117,16 @@ impl ApiClient {
let token = {
let mut auth_ref = auth.borrow_mut();
if let Some(username) = username {
// Username passed, we need to get the token for the specific username
auth_ref.get_token_store().get_oauth2_token(username)
} else {
// No username passed, we need to get the first oauth2 token
auth_ref.get_token_store().get_first_oauth2_token()
}
};
if let Some(token) = token {
match token {
Token::OAuth2(token) => Ok(format!("Bearer {}", token)),
_ => Err(Error::AuthError(AuthError::WrongTokenFoundInStore)),
}
self.validate_and_refresh_oauth2_token(auth, token, username)
.await
} else {
let oauth1_result = {
let auth_ref = auth.borrow();
Expand All @@ -121,7 +148,7 @@ impl ApiClient {
}
}

pub async fn build_request(
pub async fn build_request(
&self,
method: &str,
endpoint: &str,
Expand Down Expand Up @@ -188,8 +215,8 @@ pub async fn build_request(
let response = request_builder.send().await?;

if verbose {
println!("Request: {:#?}", req);
println!("Response: {:#?}", response)
println!("{:#?}", req);
println!("{:#?}", response)
}

let status = response.status();
Expand All @@ -201,7 +228,7 @@ pub async fn build_request(
} else {
Ok(res)
}
},
}
Err(_) => {
let status = status.to_string();
Err(Error::ApiError(serde_json::json!({
Expand All @@ -226,15 +253,22 @@ mod tests {

fn mock_auth() -> Auth {
let config = Config::from_env();
let auth = Auth::new(config)
.with_token_store(TokenStore::from_file_path(".xurl_test".into()));
let auth =
Auth::new(config).with_token_store(TokenStore::from_file_path(".xurl_test".into()));
auth
}

fn setup_tests_with_mock_oauth2_token() -> Auth {
let mut auth = mock_auth();
let token_store = auth.get_token_store();
token_store.save_oauth2_token("test", "fake_token").unwrap();
let current_time = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap()
.as_secs()
+ 7200;
token_store
.save_oauth2_token("test", "fake_token", "fake_refresh_token", current_time)
.unwrap();

auth
}
Expand Down
Loading

0 comments on commit bbff701

Please sign in to comment.