diff --git a/Cargo.lock b/Cargo.lock index 011eb5f01e0..424d65f6c69 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5282,6 +5282,7 @@ dependencies = [ "axum", "axum-extra", "bincode", + "cached", "chrono", "cron", "fs_extra", @@ -5300,6 +5301,7 @@ dependencies = [ "octocrab", "pin-project", "querystring", + "ratelimit", "reqwest", "rust-embed", "serde", diff --git a/Cargo.toml b/Cargo.toml index 7441bb8de7d..605e5269bba 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -70,6 +70,7 @@ logkit = "0.3" async-openai = "0.20" tracing-test = "0.2" clap = "4.3.0" +ratelimit = "0.10" [workspace.dependencies.uuid] version = "1.3.3" diff --git a/crates/http-api-bindings/Cargo.toml b/crates/http-api-bindings/Cargo.toml index 9ac1541ba88..7037c5958b4 100644 --- a/crates/http-api-bindings/Cargo.toml +++ b/crates/http-api-bindings/Cargo.toml @@ -18,7 +18,7 @@ tabby-common = { path = "../tabby-common" } tabby-inference = { path = "../tabby-inference" } ollama-api-bindings = { path = "../ollama-api-bindings" } async-openai.workspace = true -ratelimit = "0.10" +ratelimit.workspace = true tokio.workspace = true tracing.workspace = true diff --git a/ee/tabby-webserver/Cargo.toml b/ee/tabby-webserver/Cargo.toml index cde4f85797d..b0c192b9da0 100644 --- a/ee/tabby-webserver/Cargo.toml +++ b/ee/tabby-webserver/Cargo.toml @@ -54,6 +54,8 @@ cron = "0.12.1" async-stream.workspace = true logkit.workspace = true async-openai.workspace = true +ratelimit.workspace = true +cached.workspace = true [dev-dependencies] assert_matches.workspace = true diff --git a/ee/tabby-webserver/src/lib.rs b/ee/tabby-webserver/src/lib.rs index 2e3352fb416..06560ebd4d1 100644 --- a/ee/tabby-webserver/src/lib.rs +++ b/ee/tabby-webserver/src/lib.rs @@ -4,6 +4,7 @@ mod hub; mod jwt; mod oauth; mod path; +mod rate_limit; mod routes; mod service; mod webserver; diff --git a/ee/tabby-webserver/src/rate_limit.rs b/ee/tabby-webserver/src/rate_limit.rs new file mode 100644 index 00000000000..c85bd5f7c8c --- /dev/null +++ b/ee/tabby-webserver/src/rate_limit.rs @@ -0,0 +1,61 @@ +use std::time::Duration; + +use cached::{Cached, TimedCache}; +use tokio::sync::Mutex; + +pub struct UserRateLimiter { + /// Mapping from user ID to rate limiter. + rate_limiters: Mutex>, +} + +static USER_REQUEST_LIMIT_PER_MINUTE: u64 = 30; + +impl Default for UserRateLimiter { + fn default() -> Self { + Self { + // User rate limiter is hardcoded to 30 requests per minute, thus the timespan is 60 seconds. + rate_limiters: Mutex::new(TimedCache::with_lifespan(60)), + } + } +} + +impl UserRateLimiter { + pub async fn is_allowed(&self, user_id: &str) -> bool { + let mut rate_limiters = self.rate_limiters.lock().await; + let rate_limiter = rate_limiters.cache_get_or_set_with(user_id.to_string(), || { + // Create a new rate limiter for this user. + ratelimit::Ratelimiter::builder(USER_REQUEST_LIMIT_PER_MINUTE, Duration::from_secs(60)) + .max_tokens(USER_REQUEST_LIMIT_PER_MINUTE * 2) + .initial_available(USER_REQUEST_LIMIT_PER_MINUTE) + .build() + .expect("Failed to create rate limiter") + }); + if let Err(_sleep) = rate_limiter.try_wait() { + // If the rate limiter is full, we return false. + false + } else { + // If the rate limiter is not full, we return true. + true + } + } +} + +#[cfg(test)] +mod tests { + + use super::*; + + #[tokio::test] + async fn test_user_rate_limiter() { + let user_id = "test_user"; + let rate_limiter = UserRateLimiter::default(); + + // Test that the first `USER_REQUEST_LIMIT_PER_MINUTE` requests are allowed + for _ in 0..USER_REQUEST_LIMIT_PER_MINUTE { + assert!(rate_limiter.is_allowed(user_id).await); + } + + // Test that the 201st request is not allowed + assert!(!rate_limiter.is_allowed(user_id).await); + } +} diff --git a/ee/tabby-webserver/src/service/mod.rs b/ee/tabby-webserver/src/service/mod.rs index 0cabb751ee2..387c637ad37 100644 --- a/ee/tabby-webserver/src/service/mod.rs +++ b/ee/tabby-webserver/src/service/mod.rs @@ -61,6 +61,7 @@ use tabby_schema::{ use self::{ analytic::new_analytic_service, email::new_email_service, license::new_license_service, }; +use crate::rate_limit::UserRateLimiter; struct ServerContext { db_conn: DbConn, mail: Arc, @@ -83,6 +84,8 @@ struct ServerContext { code: Arc, setting: Arc, + + user_rate_limiter: UserRateLimiter, } impl ServerContext { @@ -153,6 +156,7 @@ impl ServerContext { user_group, access_policy, db_conn, + user_rate_limiter: UserRateLimiter::default(), } } @@ -223,6 +227,15 @@ impl WorkerService for ServerContext { } if let Some(user) = user { + // Apply rate limiting when `user` is not none. + if !self.user_rate_limiter.is_allowed(&user).await { + return axum::response::Response::builder() + .status(StatusCode::TOO_MANY_REQUESTS) + .body(Body::empty()) + .unwrap() + .into_response(); + } + request.headers_mut().append( HeaderName::from_static(USER_HEADER_FIELD_NAME), HeaderValue::from_str(&user).expect("User must be valid header"),