diff --git a/crates/collab/src/llm.rs b/crates/collab/src/llm.rs index 641ef48a5403bc..866482f6820c6b 100644 --- a/crates/collab/src/llm.rs +++ b/crates/collab/src/llm.rs @@ -18,6 +18,7 @@ use axum::{ Extension, Json, Router, TypedHeader, }; use chrono::{DateTime, Duration, Utc}; +use collections::HashMap; use db::{usage_measure::UsageMeasure, ActiveUserCount, LlmDatabase}; use futures::{Stream, StreamExt as _}; use http_client::IsahcHttpClient; @@ -41,7 +42,8 @@ pub struct LlmState { pub db: Arc<LlmDatabase>, pub http_client: IsahcHttpClient, pub clickhouse_client: Option<clickhouse::Client>, - active_user_count: RwLock<Option<(DateTime<Utc>, ActiveUserCount)>>, + active_user_count_by_model: + RwLock<HashMap<(LanguageModelProvider, String), (DateTime<Utc>, ActiveUserCount)>>, } const ACTIVE_USER_COUNT_CACHE_DURATION: Duration = Duration::seconds(30); @@ -69,9 +71,6 @@ impl LlmState { .build() .context("failed to construct http client")?; - let initial_active_user_count = - Some((Utc::now(), db.get_active_user_count(Utc::now()).await?)); - let this = Self { executor, db, @@ -80,25 +79,34 @@ impl LlmState { .clickhouse_url .as_ref() .and_then(|_| build_clickhouse_client(&config).log_err()), - active_user_count: RwLock::new(initial_active_user_count), + active_user_count_by_model: RwLock::new(HashMap::default()), config, }; Ok(Arc::new(this)) } - pub async fn get_active_user_count(&self) -> Result<ActiveUserCount> { + pub async fn get_active_user_count( + &self, + provider: LanguageModelProvider, + model: &str, + ) -> Result<ActiveUserCount> { let now = Utc::now(); - if let Some((last_updated, count)) = self.active_user_count.read().await.as_ref() { - if now - *last_updated < ACTIVE_USER_COUNT_CACHE_DURATION { - return Ok(*count); + { + let active_user_count_by_model = self.active_user_count_by_model.read().await; + if let Some((last_updated, count)) = + active_user_count_by_model.get(&(provider, model.to_string())) + { + if now - *last_updated < ACTIVE_USER_COUNT_CACHE_DURATION { + return Ok(*count); + } } } - let mut cache = self.active_user_count.write().await; - let new_count = 
self.db.get_active_user_count(now).await?; - *cache = Some((now, new_count)); + let mut cache = self.active_user_count_by_model.write().await; + let new_count = self.db.get_active_user_count(provider, model, now).await?; + cache.insert((provider, model.to_string()), (now, new_count)); Ok(new_count) } } @@ -419,7 +427,7 @@ async fn check_usage_limit( ) .await?; - let active_users = state.get_active_user_count().await?; + let active_users = state.get_active_user_count(provider, model_name).await?; let users_in_recent_minutes = active_users.users_in_recent_minutes.max(1); let users_in_recent_days = active_users.users_in_recent_days.max(1); diff --git a/crates/collab/src/llm/db/queries/usages.rs b/crates/collab/src/llm/db/queries/usages.rs index fbffca1c8939d6..5ea0c6bce252a0 100644 --- a/crates/collab/src/llm/db/queries/usages.rs +++ b/crates/collab/src/llm/db/queries/usages.rs @@ -343,15 +343,27 @@ impl LlmDatabase { .await } - pub async fn get_active_user_count(&self, now: DateTimeUtc) -> Result<ActiveUserCount> { + /// Returns the active user count for the specified model. 
+ pub async fn get_active_user_count( + &self, + provider: LanguageModelProvider, + model_name: &str, + now: DateTimeUtc, + ) -> Result<ActiveUserCount> { self.transaction(|tx| async move { let minute_since = now - Duration::minutes(5); let day_since = now - Duration::days(5); + let model = self + .models + .get(&(provider, model_name.to_string())) + .ok_or_else(|| anyhow!("unknown model {provider}:{model_name}"))?; + let users_in_recent_minutes = usage::Entity::find() .filter( - usage::Column::Timestamp - .gte(minute_since.naive_utc()) + usage::Column::ModelId + .eq(model.id) + .and(usage::Column::Timestamp.gte(minute_since.naive_utc())) .and(usage::Column::IsStaff.eq(false)), ) .select_only() @@ -362,8 +374,9 @@ impl LlmDatabase { let users_in_recent_days = usage::Entity::find() .filter( - usage::Column::Timestamp - .gte(day_since.naive_utc()) + usage::Column::ModelId + .eq(model.id) + .and(usage::Column::Timestamp.gte(day_since.naive_utc())) .and(usage::Column::IsStaff.eq(false)), ) .select_only() diff --git a/crates/collab/src/main.rs b/crates/collab/src/main.rs index 50cb81934e11b4..7c49552d0255dd 100644 --- a/crates/collab/src/main.rs +++ b/crates/collab/src/main.rs @@ -302,10 +302,7 @@ async fn handle_liveness_probe( } if let Some(llm_state) = llm_state { - llm_state - .db - .get_active_user_count(chrono::Utc::now()) - .await?; + llm_state.db.list_providers().await?; } Ok("ok".to_string())