Add logic for closed beta LLM models (#16482)
Release Notes:

- N/A

---------

Co-authored-by: Marshall <[email protected]>
maxbrunsfeld and maxdeviant authored Aug 19, 2024
1 parent 41fc6d0 · commit b5bd8a5
Showing 9 changed files with 104 additions and 47 deletions.
5 changes: 5 additions & 0 deletions crates/collab/k8s/collab.template.yml
@@ -139,6 +139,11 @@ spec:
secretKeyRef:
name: anthropic
key: staff_api_key
- name: LLM_CLOSED_BETA_MODEL_NAME
valueFrom:
secretKeyRef:
name: llm-closed-beta
key: model_name
- name: GOOGLE_AI_API_KEY
valueFrom:
secretKeyRef:
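This hunk wires the closed-beta model name into the collab deployment as an environment variable backed by the llm-closed-beta Kubernetes secret. A rough sketch of the consuming side, assuming the service maps the variable onto the new Config field (the helper name here is hypothetical; the real config-loading path lives elsewhere in collab):

use std::env;
use std::sync::Arc;

// Hypothetical helper: read LLM_CLOSED_BETA_MODEL_NAME, if set, into the
// Option<Arc<str>> shape used by collab's Config. This only illustrates the
// env var -> config field step.
fn llm_closed_beta_model_name_from_env() -> Option<Arc<str>> {
    env::var("LLM_CLOSED_BETA_MODEL_NAME")
        .ok()
        .filter(|name| !name.is_empty())
        .map(Arc::from)
}

fn main() {
    match llm_closed_beta_model_name_from_env() {
        Some(name) => println!("closed-beta model configured: {name}"),
        None => println!("no closed-beta model configured"),
    }
}
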
2 changes: 2 additions & 0 deletions crates/collab/src/lib.rs
@@ -168,6 +168,7 @@ pub struct Config {
pub google_ai_api_key: Option<Arc<str>>,
pub anthropic_api_key: Option<Arc<str>>,
pub anthropic_staff_api_key: Option<Arc<str>>,
pub llm_closed_beta_model_name: Option<Arc<str>>,
pub qwen2_7b_api_key: Option<Arc<str>>,
pub qwen2_7b_api_url: Option<Arc<str>>,
pub zed_client_checksum_seed: Option<String>,
@@ -219,6 +220,7 @@ impl Config {
google_ai_api_key: None,
anthropic_api_key: None,
anthropic_staff_api_key: None,
llm_closed_beta_model_name: None,
clickhouse_url: None,
clickhouse_user: None,
clickhouse_password: None,
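The Config struct gains an optional llm_closed_beta_model_name, and the test constructor defaults it to None so existing tests run without the secret. A trimmed sketch of that shape (the field set is abbreviated; only the new field is taken from the diff):

use std::sync::Arc;

// Abbreviated stand-in for collab's Config: the closed-beta model name is
// optional, so deployments without a closed beta simply leave it unset.
struct Config {
    anthropic_api_key: Option<Arc<str>>,
    llm_closed_beta_model_name: Option<Arc<str>>,
}

impl Config {
    // Mirrors the test constructor in the diff: optional secrets default to None.
    fn test() -> Self {
        Self {
            anthropic_api_key: None,
            llm_closed_beta_model_name: None,
        }
    }
}

fn main() {
    let config = Config::test();
    assert!(config.llm_closed_beta_model_name.is_none());
    assert!(config.anthropic_api_key.is_none());
}
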
27 changes: 20 additions & 7 deletions crates/collab/src/llm/authorization.rs
@@ -12,11 +12,12 @@ pub fn authorize_access_to_language_model(
model: &str,
) -> Result<()> {
authorize_access_for_country(config, country_code, provider)?;
authorize_access_to_model(claims, provider, model)?;
authorize_access_to_model(config, claims, provider, model)?;
Ok(())
}

fn authorize_access_to_model(
config: &Config,
claims: &LlmTokenClaims,
provider: LanguageModelProvider,
model: &str,
@@ -25,13 +26,25 @@ fn authorize_access_to_model(
return Ok(());
}

match (provider, model) {
(LanguageModelProvider::Anthropic, "claude-3-5-sonnet") => Ok(()),
_ => Err(Error::http(
StatusCode::FORBIDDEN,
format!("access to model {model:?} is not included in your plan"),
))?,
match provider {
LanguageModelProvider::Anthropic => {
if model == "claude-3-5-sonnet" {
return Ok(());
}

if claims.has_llm_closed_beta_feature_flag
&& Some(model) == config.llm_closed_beta_model_name.as_deref()
{
return Ok(());
}
}
_ => {}
}

Err(Error::http(
StatusCode::FORBIDDEN,
format!("access to model {model:?} is not included in your plan"),
))
}

fn authorize_access_for_country(
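The rewritten check keeps claude-3-5-sonnet available to everyone on a plan and additionally allows the configured closed-beta model when the token carries the llm-closed-beta flag. A self-contained sketch of that flow with the collab types stubbed out (it assumes the elided early return guards on claims.is_staff, and the model name used in main is only illustrative):

#[derive(Clone, Copy, PartialEq)]
enum LanguageModelProvider {
    Anthropic,
    OpenAi,
}

struct LlmTokenClaims {
    is_staff: bool,
    has_llm_closed_beta_feature_flag: bool,
}

struct Config {
    llm_closed_beta_model_name: Option<String>,
}

// Stubbed version of authorize_access_to_model: the real function returns the
// collab crate's Result and a FORBIDDEN http error; a String error stands in here.
fn authorize_access_to_model(
    config: &Config,
    claims: &LlmTokenClaims,
    provider: LanguageModelProvider,
    model: &str,
) -> Result<(), String> {
    if claims.is_staff {
        return Ok(());
    }

    if provider == LanguageModelProvider::Anthropic {
        // The default model stays available to every plan.
        if model == "claude-3-5-sonnet" {
            return Ok(());
        }
        // Closed-beta users may also use the model named in the config.
        if claims.has_llm_closed_beta_feature_flag
            && Some(model) == config.llm_closed_beta_model_name.as_deref()
        {
            return Ok(());
        }
    }

    Err(format!(
        "access to model {model:?} is not included in your plan"
    ))
}

fn main() {
    let config = Config {
        llm_closed_beta_model_name: Some("beta-model".into()),
    };
    let beta_user = LlmTokenClaims {
        is_staff: false,
        has_llm_closed_beta_feature_flag: true,
    };
    let regular_user = LlmTokenClaims {
        is_staff: false,
        has_llm_closed_beta_feature_flag: false,
    };

    assert!(authorize_access_to_model(&config, &beta_user, LanguageModelProvider::Anthropic, "beta-model").is_ok());
    assert!(authorize_access_to_model(&config, &regular_user, LanguageModelProvider::Anthropic, "beta-model").is_err());
    assert!(authorize_access_to_model(&config, &regular_user, LanguageModelProvider::OpenAi, "beta-model").is_err());
}
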
7 changes: 4 additions & 3 deletions crates/collab/src/llm/db/queries/usages.rs
@@ -82,12 +82,13 @@ impl LlmDatabase {
let tokens_per_minute = self.usage_measure_ids[&UsageMeasure::TokensPerMinute];

let mut results = Vec::new();
for (provider, model) in self.models.keys().cloned() {
for ((provider, model_name), model) in self.models.iter() {
let mut usages = usage::Entity::find()
.filter(
usage::Column::Timestamp
.gte(past_minute.naive_utc())
.and(usage::Column::IsStaff.eq(false))
.and(usage::Column::ModelId.eq(model.id))
.and(
usage::Column::MeasureId
.eq(requests_per_minute)
@@ -125,8 +126,8 @@ impl LlmDatabase {
}

results.push(ApplicationWideUsage {
provider,
model,
provider: *provider,
model: model_name.clone(),
requests_this_minute,
tokens_this_minute,
})
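Iterating self.models.iter() instead of self.models.keys().cloned() exposes the model record alongside its (provider, name) key, which is what enables the new usage::Column::ModelId.eq(model.id) filter. A simplified sketch of that iteration pattern with stand-in types (the real map and model types live in collab's LLM database layer):

use std::collections::HashMap;

#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
enum LanguageModelProvider {
    Anthropic,
}

// Stand-in for the database model record; only the id matters here.
struct Model {
    id: i32,
}

fn main() {
    let mut models: HashMap<(LanguageModelProvider, String), Model> = HashMap::new();
    models.insert(
        (LanguageModelProvider::Anthropic, "claude-3-5-sonnet".to_string()),
        Model { id: 1 },
    );

    for ((provider, model_name), model) in models.iter() {
        // The real code builds a SeaORM query here and now adds
        // usage::Column::ModelId.eq(model.id) so each model's usage is counted separately.
        println!(
            "querying usage for {provider:?}/{model_name} (model id {})",
            model.id
        );
    }
}
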
4 changes: 4 additions & 0 deletions crates/collab/src/llm/token.rs
@@ -20,6 +20,8 @@ pub struct LlmTokenClaims {
#[serde(default)]
pub github_user_login: Option<String>,
pub is_staff: bool,
#[serde(default)]
pub has_llm_closed_beta_feature_flag: bool,
pub plan: rpc::proto::Plan,
}

@@ -30,6 +32,7 @@ impl LlmTokenClaims {
user_id: UserId,
github_user_login: String,
is_staff: bool,
has_llm_closed_beta_feature_flag: bool,
plan: rpc::proto::Plan,
config: &Config,
) -> Result<String> {
@@ -46,6 +49,7 @@ impl LlmTokenClaims {
user_id: user_id.to_proto(),
github_user_login: Some(github_user_login),
is_staff,
has_llm_closed_beta_feature_flag,
plan,
};

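The new has_llm_closed_beta_feature_flag claim is marked #[serde(default)], so tokens minted before this change still deserialize, with the flag read as false. A small sketch of that behavior with a trimmed claims struct (assumes serde and serde_json as dependencies, which the real crate already uses):

use serde::Deserialize;

#[derive(Deserialize, Debug)]
struct LlmTokenClaims {
    is_staff: bool,
    #[serde(default)]
    has_llm_closed_beta_feature_flag: bool,
}

fn main() {
    // An "old" token payload without the new field still parses; the flag defaults to false.
    let old: LlmTokenClaims = serde_json::from_str(r#"{"is_staff": false}"#).unwrap();
    assert!(!old.has_llm_closed_beta_feature_flag);

    // A "new" token payload carries the flag explicitly.
    let new: LlmTokenClaims = serde_json::from_str(
        r#"{"is_staff": false, "has_llm_closed_beta_feature_flag": true}"#,
    )
    .unwrap();
    assert!(new.has_llm_closed_beta_feature_flag);

    println!("{old:?} {new:?}");
}
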
6 changes: 5 additions & 1 deletion crates/collab/src/rpc.rs
@@ -4918,7 +4918,10 @@ async fn get_llm_api_token(
let db = session.db().await;

let flags = db.get_user_flags(session.user_id()).await?;
if !session.is_staff() && !flags.iter().any(|flag| flag == "language-models") {
let has_language_models_feature_flag = flags.iter().any(|flag| flag == "language-models");
let has_llm_closed_beta_feature_flag = flags.iter().any(|flag| flag == "llm-closed-beta");

if !session.is_staff() && !has_language_models_feature_flag {
Err(anyhow!("permission denied"))?
}

@@ -4943,6 +4946,7 @@ async fn get_llm_api_token(
user.id,
user.github_login.clone(),
session.is_staff(),
has_llm_closed_beta_feature_flag,
session.current_plan(db).await?,
&session.app_state.config,
)?;
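get_llm_api_token now looks up both flags from the user's feature-flag list: "language-models" still gates token issuance for non-staff users, while "llm-closed-beta" is simply forwarded into the token claims. A condensed sketch of that gating logic (the real handler is async and takes the collab session and database types; the function below is a stand-in):

fn check_llm_token_access(is_staff: bool, flags: &[&str]) -> Result<bool, String> {
    let has_language_models_feature_flag = flags.iter().any(|flag| *flag == "language-models");
    let has_llm_closed_beta_feature_flag = flags.iter().any(|flag| *flag == "llm-closed-beta");

    if !is_staff && !has_language_models_feature_flag {
        return Err("permission denied".to_string());
    }

    // In the real handler this boolean is passed on to LlmTokenClaims::create(...).
    Ok(has_llm_closed_beta_feature_flag)
}

fn main() {
    assert_eq!(check_llm_token_access(false, &["language-models", "llm-closed-beta"]), Ok(true));
    assert_eq!(check_llm_token_access(false, &["language-models"]), Ok(false));
    assert!(check_llm_token_access(false, &[]).is_err());
    assert!(check_llm_token_access(true, &[]).is_ok());
}
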
1 change: 1 addition & 0 deletions crates/collab/src/tests/test_server.rs
@@ -667,6 +667,7 @@ impl TestServer {
google_ai_api_key: None,
anthropic_api_key: None,
anthropic_staff_api_key: None,
llm_closed_beta_model_name: None,
clickhouse_url: None,
clickhouse_user: None,
clickhouse_password: None,
5 changes: 5 additions & 0 deletions crates/feature_flags/src/feature_flags.rs
@@ -43,6 +43,11 @@ impl FeatureFlag for LanguageModels {
const NAME: &'static str = "language-models";
}

pub struct LlmClosedBeta {}
impl FeatureFlag for LlmClosedBeta {
const NAME: &'static str = "llm-closed-beta";
}

pub struct ZedPro {}
impl FeatureFlag for ZedPro {
const NAME: &'static str = "zed-pro";
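LlmClosedBeta follows the crate's existing pattern: a zero-sized marker type whose NAME matches the server-side flag string, queried generically on the client. A minimal sketch of that pattern (the trait and helper below are stand-ins for the feature_flags crate's FeatureFlag trait and cx.has_flag::<T>()):

trait FeatureFlag {
    const NAME: &'static str;
}

pub struct LlmClosedBeta {}
impl FeatureFlag for LlmClosedBeta {
    const NAME: &'static str = "llm-closed-beta";
}

// Stand-in for cx.has_flag::<T>(): checks whether the flag's name is in the
// set of flags enabled for the current user.
fn has_flag<T: FeatureFlag>(enabled_flags: &[&str]) -> bool {
    enabled_flags.iter().any(|flag| *flag == T::NAME)
}

fn main() {
    assert!(has_flag::<LlmClosedBeta>(&["llm-closed-beta"]));
    assert!(!has_flag::<LlmClosedBeta>(&["language-models"]));
}
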
94 changes: 58 additions & 36 deletions crates/language_model/src/provider/cloud.rs
@@ -8,7 +8,7 @@ use anthropic::AnthropicError;
use anyhow::{anyhow, Result};
use client::{Client, PerformCompletionParams, UserStore, EXPIRED_LLM_TOKEN_HEADER_NAME};
use collections::BTreeMap;
use feature_flags::{FeatureFlagAppExt, ZedPro};
use feature_flags::{FeatureFlagAppExt, LlmClosedBeta, ZedPro};
use futures::{
future::BoxFuture, stream::BoxStream, AsyncBufReadExt, FutureExt, Stream, StreamExt,
TryStreamExt as _,
@@ -26,7 +26,10 @@ use smol::{
io::{AsyncReadExt, BufReader},
lock::{RwLock, RwLockUpgradableReadGuard, RwLockWriteGuard},
};
use std::{future, sync::Arc};
use std::{
future,
sync::{Arc, LazyLock},
};
use strum::IntoEnumIterator;
use ui::prelude::*;

@@ -37,6 +40,18 @@ use super::anthropic::count_anthropic_tokens;
pub const PROVIDER_ID: &str = "zed.dev";
pub const PROVIDER_NAME: &str = "Zed";

const ZED_CLOUD_PROVIDER_ADDITIONAL_MODELS_JSON: Option<&str> =
option_env!("ZED_CLOUD_PROVIDER_ADDITIONAL_MODELS_JSON");

fn zed_cloud_provider_additional_models() -> &'static [AvailableModel] {
static ADDITIONAL_MODELS: LazyLock<Vec<AvailableModel>> = LazyLock::new(|| {
ZED_CLOUD_PROVIDER_ADDITIONAL_MODELS_JSON
.map(|json| serde_json::from_str(json).unwrap())
.unwrap_or(Vec::new())
});
ADDITIONAL_MODELS.as_slice()
}

#[derive(Default, Clone, Debug, PartialEq)]
pub struct ZedDotDevSettings {
pub available_models: Vec<AvailableModel>,
@@ -200,47 +215,54 @@ impl LanguageModelProvider for CloudLanguageModelProvider {
for model in ZedModel::iter() {
models.insert(model.id().to_string(), CloudModel::Zed(model));
}

// Override with available models from settings
for model in &AllLanguageModelSettings::get_global(cx)
.zed_dot_dev
.available_models
{
let model = match model.provider {
AvailableProvider::Anthropic => {
CloudModel::Anthropic(anthropic::Model::Custom {
name: model.name.clone(),
display_name: model.display_name.clone(),
max_tokens: model.max_tokens,
tool_override: model.tool_override.clone(),
cache_configuration: model.cache_configuration.as_ref().map(|config| {
anthropic::AnthropicModelCacheConfiguration {
max_cache_anchors: config.max_cache_anchors,
should_speculate: config.should_speculate,
min_total_token: config.min_total_token,
}
}),
max_output_tokens: model.max_output_tokens,
})
}
AvailableProvider::OpenAi => CloudModel::OpenAi(open_ai::Model::Custom {
name: model.name.clone(),
max_tokens: model.max_tokens,
}),
AvailableProvider::Google => CloudModel::Google(google_ai::Model::Custom {
name: model.name.clone(),
max_tokens: model.max_tokens,
}),
};
models.insert(model.id().to_string(), model.clone());
}
} else {
models.insert(
anthropic::Model::Claude3_5Sonnet.id().to_string(),
CloudModel::Anthropic(anthropic::Model::Claude3_5Sonnet),
);
}

let llm_closed_beta_models = if cx.has_flag::<LlmClosedBeta>() {
zed_cloud_provider_additional_models()
} else {
&[]
};

// Override with available models from settings
for model in AllLanguageModelSettings::get_global(cx)
.zed_dot_dev
.available_models
.iter()
.chain(llm_closed_beta_models)
.cloned()
{
let model = match model.provider {
AvailableProvider::Anthropic => CloudModel::Anthropic(anthropic::Model::Custom {
name: model.name.clone(),
display_name: model.display_name.clone(),
max_tokens: model.max_tokens,
tool_override: model.tool_override.clone(),
cache_configuration: model.cache_configuration.as_ref().map(|config| {
anthropic::AnthropicModelCacheConfiguration {
max_cache_anchors: config.max_cache_anchors,
should_speculate: config.should_speculate,
min_total_token: config.min_total_token,
}
}),
max_output_tokens: model.max_output_tokens,
}),
AvailableProvider::OpenAi => CloudModel::OpenAi(open_ai::Model::Custom {
name: model.name.clone(),
max_tokens: model.max_tokens,
}),
AvailableProvider::Google => CloudModel::Google(google_ai::Model::Custom {
name: model.name.clone(),
max_tokens: model.max_tokens,
}),
};
models.insert(model.id().to_string(), model.clone());
}

models
.into_values()
.map(|model| {
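The interesting piece here is how the extra models reach the client: a compile-time option_env! JSON blob parsed once behind a LazyLock, surfaced only when the LlmClosedBeta flag is set, and then merged into the model list through the same path as user-configured models from settings. A standalone sketch of the option_env!/LazyLock half (AvailableModel is trimmed to two fields, and the example JSON and model name are illustrative, not the actual closed-beta configuration; assumes serde/serde_json and Rust 1.80+ for LazyLock):

use serde::Deserialize;
use std::sync::LazyLock;

#[derive(Deserialize, Debug, Clone)]
struct AvailableModel {
    name: String,
    max_tokens: usize,
}

const ZED_CLOUD_PROVIDER_ADDITIONAL_MODELS_JSON: Option<&str> =
    option_env!("ZED_CLOUD_PROVIDER_ADDITIONAL_MODELS_JSON");

fn zed_cloud_provider_additional_models() -> &'static [AvailableModel] {
    static ADDITIONAL_MODELS: LazyLock<Vec<AvailableModel>> = LazyLock::new(|| {
        ZED_CLOUD_PROVIDER_ADDITIONAL_MODELS_JSON
            .map(|json| serde_json::from_str(json).expect("invalid additional models JSON"))
            .unwrap_or_default()
    });
    ADDITIONAL_MODELS.as_slice()
}

fn main() {
    // Built without the env var set, this list is empty. A closed-beta build
    // would set the variable at compile time to something like:
    //   [{"name": "some-beta-model", "max_tokens": 200000}]
    println!("additional models: {:?}", zed_cloud_provider_additional_models());
}
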
