Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(rate_limit): implement user rate limiting in tabby-webserver #3484

Merged
merged 9 commits into from
Nov 29, 2024
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ logkit = "0.3"
async-openai = "0.20"
tracing-test = "0.2"
clap = "4.3.0"
ratelimit = "0.10"

[workspace.dependencies.uuid]
version = "1.3.3"
Expand Down
2 changes: 1 addition & 1 deletion crates/http-api-bindings/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ tabby-common = { path = "../tabby-common" }
tabby-inference = { path = "../tabby-inference" }
ollama-api-bindings = { path = "../ollama-api-bindings" }
async-openai.workspace = true
ratelimit = "0.10"
ratelimit.workspace = true
tokio.workspace = true
tracing.workspace = true

Expand Down
2 changes: 2 additions & 0 deletions ee/tabby-webserver/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,8 @@ cron = "0.12.1"
async-stream.workspace = true
logkit.workspace = true
async-openai.workspace = true
ratelimit.workspace = true
cached.workspace = true

[dev-dependencies]
assert_matches.workspace = true
Expand Down
1 change: 1 addition & 0 deletions ee/tabby-webserver/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ mod hub;
mod jwt;
mod oauth;
mod path;
mod rate_limit;
mod routes;
mod service;
mod webserver;
Expand Down
61 changes: 61 additions & 0 deletions ee/tabby-webserver/src/rate_limit.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
use std::time::Duration;

use cached::{Cached, TimedCache};
use tokio::sync::Mutex;

pub struct UserRateLimiter {
/// Mapping from user ID to rate limiter.
rate_limiters: Mutex<TimedCache<String, ratelimit::Ratelimiter>>,
}

static USER_REQUEST_LIMIT_PER_MINUTE: u64 = 30;

impl Default for UserRateLimiter {
fn default() -> Self {
Self {
// User rate limiter is hardcoded to 200 requests per minute, thus the timespan is 60 seconds.
wsxiaoys marked this conversation as resolved.
Show resolved Hide resolved
rate_limiters: Mutex::new(TimedCache::with_lifespan(60)),
}
}
}

impl UserRateLimiter {
pub async fn is_allowed(&self, user_id: &str) -> bool {
let mut rate_limiters = self.rate_limiters.lock().await;
let rate_limiter = rate_limiters.cache_get_or_set_with(user_id.to_string(), || {
// Create a new rate limiter for this user.
ratelimit::Ratelimiter::builder(USER_REQUEST_LIMIT_PER_MINUTE, Duration::from_secs(60))
.max_tokens(USER_REQUEST_LIMIT_PER_MINUTE * 2)
.initial_available(USER_REQUEST_LIMIT_PER_MINUTE)
.build()
.expect("Failed to create rate limiter")
});
if let Err(_sleep) = rate_limiter.try_wait() {
// If the rate limiter is full, we return false.
false
} else {
// If the rate limiter is not full, we return true.
true
}
}
}

#[cfg(test)]
mod tests {

use super::*;

#[tokio::test]
async fn test_user_rate_limiter() {
let user_id = "test_user";
let rate_limiter = UserRateLimiter::default();

// Test that the first 200 requests are allowed
wsxiaoys marked this conversation as resolved.
Show resolved Hide resolved
for _ in 0..USER_REQUEST_LIMIT_PER_MINUTE {
assert!(rate_limiter.is_allowed(user_id).await);
}

// Test that the 201st request is not allowed
wsxiaoys marked this conversation as resolved.
Show resolved Hide resolved
assert!(!rate_limiter.is_allowed(user_id).await);
}
}
13 changes: 13 additions & 0 deletions ee/tabby-webserver/src/service/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@
use self::{
analytic::new_analytic_service, email::new_email_service, license::new_license_service,
};
use crate::rate_limit::UserRateLimiter;
struct ServerContext {
db_conn: DbConn,
mail: Arc<dyn EmailService>,
Expand All @@ -83,6 +84,8 @@
code: Arc<dyn CodeSearch>,

setting: Arc<dyn SettingService>,

user_rate_limiter: UserRateLimiter,
}

impl ServerContext {
Expand Down Expand Up @@ -153,6 +156,7 @@
user_group,
access_policy,
db_conn,
user_rate_limiter: UserRateLimiter::default(),

Check warning on line 159 in ee/tabby-webserver/src/service/mod.rs

View check run for this annotation

Codecov / codecov/patch

ee/tabby-webserver/src/service/mod.rs#L159

Added line #L159 was not covered by tests
}
}

Expand Down Expand Up @@ -223,6 +227,15 @@
}

if let Some(user) = user {
// Apply rate limiting when `user` is not none.
if !self.user_rate_limiter.is_allowed(&user).await {
return axum::response::Response::builder()
.status(StatusCode::TOO_MANY_REQUESTS)
.body(Body::empty())
.unwrap()
.into_response();
}

Check warning on line 238 in ee/tabby-webserver/src/service/mod.rs

View check run for this annotation

Codecov / codecov/patch

ee/tabby-webserver/src/service/mod.rs#L231-L238

Added lines #L231 - L238 were not covered by tests
request.headers_mut().append(
HeaderName::from_static(USER_HEADER_FIELD_NAME),
HeaderValue::from_str(&user).expect("User must be valid header"),
Expand Down
Loading