From 16c230e570fa15cb183110602da0fcb36653e4e3 Mon Sep 17 00:00:00 2001
From: Timshel <Timshel@users.noreply.github.com>
Date: Fri, 10 Jan 2025 17:38:40 +0100
Subject: [PATCH] Add wrapper type OIDCCode OIDCState OIDCIdentifier

---
 src/api/identity.rs        |  18 +++---
 src/db/models/sso_nonce.rs |  10 ++--
 src/db/models/user.rs      |   3 +-
 src/sso.rs                 | 110 ++++++++++++++++++++++++++++++++-----
 4 files changed, 111 insertions(+), 30 deletions(-)

diff --git a/src/api/identity.rs b/src/api/identity.rs
index 2ead4db304..4754152824 100644
--- a/src/api/identity.rs
+++ b/src/api/identity.rs
@@ -23,7 +23,9 @@ use crate::{
     auth::{AuthMethod, ClientHeaders, ClientIp},
     db::{models::*, DbConn},
     error::MapResult,
-    mail, sso, util, CONFIG,
+    mail, sso,
+    sso::{OIDCCode, OIDCState},
+    util, CONFIG,
 };
 
 pub fn routes() -> Vec<Route> {
@@ -968,7 +970,7 @@ fn prevalidate() -> JsonResult {
 }
 
 #[get("/connect/oidc-signin?<code>&<state>", rank = 1)]
-async fn oidcsignin(code: String, state: String, conn: DbConn) -> ApiResult<Redirect> {
+async fn oidcsignin(code: OIDCCode, state: String, conn: DbConn) -> ApiResult<Redirect> {
     oidcsignin_redirect(
         state,
         |decoded_state| sso::OIDCCodeWrapper::Ok {
@@ -1005,16 +1007,10 @@ async fn oidcsignin_error(
 // iss and scope parameters are needed for redirection to work on IOS.
 async fn oidcsignin_redirect(
     base64_state: String,
-    wrapper: impl FnOnce(String) -> sso::OIDCCodeWrapper,
+    wrapper: impl FnOnce(OIDCState) -> sso::OIDCCodeWrapper,
     conn: &DbConn,
 ) -> ApiResult<Redirect> {
-    let state = match data_encoding::BASE64.decode(base64_state.as_bytes()) {
-        Ok(vec) => match String::from_utf8(vec) {
-            Ok(valid) => valid,
-            Err(_) => err!(format!("Invalid utf8 chars in {base64_state} after base64 decoding")),
-        },
-        Err(_) => err!(format!("Failed to decode {base64_state} using base64")),
-    };
+    let state = sso::deocde_state(base64_state)?;
     let code = sso::encode_code_claims(wrapper(state.clone()));
 
     let nonce = match SsoNonce::find(&state, conn).await {
@@ -1050,7 +1046,7 @@ struct AuthorizeData {
     response_type: Option<String>,
     #[allow(unused)]
     scope: Option<String>,
-    state: String,
+    state: OIDCState,
     #[allow(unused)]
     code_challenge: Option<String>,
     #[allow(unused)]
diff --git a/src/db/models/sso_nonce.rs b/src/db/models/sso_nonce.rs
index 881f075bfe..2246a43741 100644
--- a/src/db/models/sso_nonce.rs
+++ b/src/db/models/sso_nonce.rs
@@ -3,14 +3,14 @@ use chrono::{NaiveDateTime, Utc};
 use crate::api::EmptyResult;
 use crate::db::{DbConn, DbPool};
 use crate::error::MapResult;
-use crate::sso::NONCE_EXPIRATION;
+use crate::sso::{OIDCState, NONCE_EXPIRATION};
 
 db_object! {
     #[derive(Identifiable, Queryable, Insertable)]
     #[diesel(table_name = sso_nonce)]
     #[diesel(primary_key(state))]
     pub struct SsoNonce {
-        pub state: String,
+        pub state: OIDCState,
         pub nonce: String,
         pub verifier: Option<String>,
         pub redirect_uri: String,
@@ -20,7 +20,7 @@ db_object! {
 
 /// Local methods
 impl SsoNonce {
-    pub fn new(state: String, nonce: String, verifier: Option<String>, redirect_uri: String) -> Self {
+    pub fn new(state: OIDCState, nonce: String, verifier: Option<String>, redirect_uri: String) -> Self {
         let now = Utc::now().naive_utc();
 
         SsoNonce {
@@ -53,7 +53,7 @@ impl SsoNonce {
         }
     }
 
-    pub async fn delete(state: &str, conn: &mut DbConn) -> EmptyResult {
+    pub async fn delete(state: &OIDCState, conn: &mut DbConn) -> EmptyResult {
         db_run! { conn: {
             diesel::delete(sso_nonce::table.filter(sso_nonce::state.eq(state)))
                 .execute(conn)
@@ -61,7 +61,7 @@ impl SsoNonce {
         }}
     }
 
-    pub async fn find(state: &str, conn: &DbConn) -> Option<Self> {
+    pub async fn find(state: &OIDCState, conn: &DbConn) -> Option<Self> {
         let oldest = Utc::now().naive_utc() - *NONCE_EXPIRATION;
         db_run! { conn: {
             sso_nonce::table
diff --git a/src/db/models/user.rs b/src/db/models/user.rs
index 95b28d99f9..4a3b25fb2f 100644
--- a/src/db/models/user.rs
+++ b/src/db/models/user.rs
@@ -10,6 +10,7 @@ use crate::{
     crypto,
     db::DbConn,
     error::MapResult,
+    sso::OIDCIdentifier,
     util::{format_date, get_uuid, retry},
     CONFIG,
 };
@@ -77,7 +78,7 @@ db_object! {
     #[diesel(primary_key(user_uuid))]
     pub struct SsoUser {
         pub user_uuid: UserId,
-        pub identifier: String,
+        pub identifier: OIDCIdentifier,
     }
 }
 
diff --git a/src/sso.rs b/src/sso.rs
index 2fbb52060c..19531b27f4 100644
--- a/src/sso.rs
+++ b/src/sso.rs
@@ -1,4 +1,5 @@
 use chrono::Utc;
+use derive_more::{AsRef, Deref, Display, From};
 use regex::Regex;
 use std::borrow::Cow;
 use std::time::Duration;
@@ -27,7 +28,7 @@ use crate::{
     CONFIG,
 };
 
-static AC_CACHE: Lazy<Cache<String, AuthenticatedUser>> =
+static AC_CACHE: Lazy<Cache<OIDCState, AuthenticatedUser>> =
     Lazy::new(|| Cache::builder().max_capacity(1000).time_to_live(Duration::from_secs(10 * 60)).build());
 
 static CLIENT_CACHE_KEY: Lazy<String> = Lazy::new(|| "sso-client".to_string());
@@ -54,6 +55,46 @@ impl<'a, AD: AuthDisplay, P: AuthPrompt, RT: ResponseType> AuthorizationRequestE
     }
 }
 
+#[derive(
+    Clone,
+    Debug,
+    Default,
+    DieselNewType,
+    FromForm,
+    PartialEq,
+    Eq,
+    Hash,
+    Serialize,
+    Deserialize,
+    AsRef,
+    Deref,
+    Display,
+    From,
+)]
+#[deref(forward)]
+#[from(forward)]
+pub struct OIDCCode(String);
+
+#[derive(
+    Clone,
+    Debug,
+    Default,
+    DieselNewType,
+    FromForm,
+    PartialEq,
+    Eq,
+    Hash,
+    Serialize,
+    Deserialize,
+    AsRef,
+    Deref,
+    Display,
+    From,
+)]
+#[deref(forward)]
+#[from(forward)]
+pub struct OIDCState(String);
+
 #[derive(Debug, Serialize, Deserialize)]
 struct SsoTokenJwtClaims {
     // Not before
@@ -81,11 +122,11 @@ pub fn encode_ssotoken_claims() -> String {
 #[derive(Debug, Serialize, Deserialize)]
 pub enum OIDCCodeWrapper {
     Ok {
-        state: String,
-        code: String,
+        state: OIDCState,
+        code: OIDCCode,
     },
     Error {
-        state: String,
+        state: OIDCState,
         error: String,
         error_description: Option<String>,
     },
@@ -208,12 +249,29 @@ impl CoreClientExt for CoreClient {
     }
 }
 
+pub fn deocde_state(base64_state: String) -> ApiResult<OIDCState> {
+    let state = match data_encoding::BASE64.decode(base64_state.as_bytes()) {
+        Ok(vec) => match String::from_utf8(vec) {
+            Ok(valid) => OIDCState(valid),
+            Err(_) => err!(format!("Invalid utf8 chars in {base64_state} after base64 decoding")),
+        },
+        Err(_) => err!(format!("Failed to decode {base64_state} using base64")),
+    };
+
+    Ok(state)
+}
+
 // The `nonce` allow to protect against replay attacks
 // The `state` is encoded using base64 to ensure no issue with providers (It contains the Organization identifier).
 // redirect_uri from: https://github.com/bitwarden/server/blob/main/src/Identity/IdentityServer/ApiClient.cs
-pub async fn authorize_url(state: String, client_id: &str, raw_redirect_uri: &str, mut conn: DbConn) -> ApiResult<Url> {
+pub async fn authorize_url(
+    state: OIDCState,
+    client_id: &str,
+    raw_redirect_uri: &str,
+    mut conn: DbConn,
+) -> ApiResult<Url> {
     let scopes = CONFIG.sso_scopes_vec().into_iter().map(Scope::new);
-    let base64_state = data_encoding::BASE64.encode(state.as_bytes());
+    let base64_state = data_encoding::BASE64.encode(state.to_string().as_bytes());
 
     let redirect_uri = match client_id {
         "web" | "browser" => format!("{}/sso-connector.html", CONFIG.domain()),
@@ -254,12 +312,38 @@ pub async fn authorize_url(state: String, client_id: &str, raw_redirect_uri: &st
     Ok(auth_url)
 }
 
+#[derive(
+    Clone,
+    Debug,
+    Default,
+    DieselNewType,
+    FromForm,
+    PartialEq,
+    Eq,
+    Hash,
+    Serialize,
+    Deserialize,
+    AsRef,
+    Deref,
+    Display,
+    From,
+)]
+#[deref(forward)]
+#[from(forward)]
+pub struct OIDCIdentifier(String);
+
+impl OIDCIdentifier {
+    fn new(issuer: &str, subject: &str) -> Self {
+        OIDCIdentifier(format!("{}/{}", issuer, subject))
+    }
+}
+
 #[derive(Clone, Debug)]
 pub struct AuthenticatedUser {
     pub refresh_token: Option<String>,
     pub access_token: String,
     pub expires_in: Option<Duration>,
-    pub identifier: String,
+    pub identifier: OIDCIdentifier,
     pub email: String,
     pub email_verified: Option<bool>,
     pub user_name: Option<String>,
@@ -267,14 +351,14 @@ pub struct AuthenticatedUser {
 
 #[derive(Clone, Debug)]
 pub struct UserInformation {
-    pub state: String,
-    pub identifier: String,
+    pub state: OIDCState,
+    pub identifier: OIDCIdentifier,
     pub email: String,
     pub email_verified: Option<bool>,
     pub user_name: Option<String>,
 }
 
-async fn decode_code_claims(code: &str, conn: &mut DbConn) -> ApiResult<(String, String)> {
+async fn decode_code_claims(code: &str, conn: &mut DbConn) -> ApiResult<(OIDCCode, OIDCState)> {
     match auth::decode_jwt::<OIDCCodeClaims>(code, SSO_JWT_ISSUER.to_string()) {
         Ok(code_claims) => match code_claims.code {
             OIDCCodeWrapper::Ok {
@@ -317,7 +401,7 @@ pub async fn exchange_code(wrapped_code: &str, conn: &mut DbConn) -> ApiResult<U
         });
     }
 
-    let oidc_code = AuthorizationCode::new(code.clone());
+    let oidc_code = AuthorizationCode::new(code.to_string());
     let client = CoreClient::cached().await?;
 
     let nonce = match SsoNonce::find(&state, conn).await {
@@ -377,7 +461,7 @@ pub async fn exchange_code(wrapped_code: &str, conn: &mut DbConn) -> ApiResult<U
                 error!("Scope offline_access is present but response contain no refresh_token");
             }
 
-            let identifier = format!("{}/{}", **id_claims.issuer(), **id_claims.subject());
+            let identifier = OIDCIdentifier::new(id_claims.issuer(), id_claims.subject());
 
             let authenticated_user = AuthenticatedUser {
                 refresh_token,
@@ -404,7 +488,7 @@ pub async fn exchange_code(wrapped_code: &str, conn: &mut DbConn) -> ApiResult<U
 }
 
 // User has passed 2FA flow we can delete `nonce` and clear the cache.
-pub async fn redeem(state: &String, conn: &mut DbConn) -> ApiResult<AuthenticatedUser> {
+pub async fn redeem(state: &OIDCState, conn: &mut DbConn) -> ApiResult<AuthenticatedUser> {
     if let Err(err) = SsoNonce::delete(state, conn).await {
         error!("Failed to delete database sso_nonce using {state}: {err}")
     }