Skip to content

Commit

Permalink
Add Qwen2-7B to the list of zed.dev models (#15649)
Browse files Browse the repository at this point in the history
Release Notes:

- N/A

---------

Co-authored-by: Nathan <[email protected]>
  • Loading branch information
as-cii and nathansobo authored Aug 1, 2024
1 parent 60127f2 commit 21816d1
Show file tree
Hide file tree
Showing 9 changed files with 112 additions and 2 deletions.
10 changes: 10 additions & 0 deletions crates/collab/k8s/collab.template.yml
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,16 @@ spec:
secretKeyRef:
name: google-ai
key: api_key
- name: QWEN2_7B_API_KEY
valueFrom:
secretKeyRef:
name: hugging-face
key: api_key
- name: QWEN2_7B_API_URL
valueFrom:
secretKeyRef:
name: hugging-face
key: qwen2_api_url
- name: BLOB_STORE_ACCESS_KEY
valueFrom:
secretKeyRef:
Expand Down
2 changes: 2 additions & 0 deletions crates/collab/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,8 @@ pub struct Config {
pub openai_api_key: Option<Arc<str>>,
pub google_ai_api_key: Option<Arc<str>>,
pub anthropic_api_key: Option<Arc<str>>,
pub qwen2_7b_api_key: Option<Arc<str>>,
pub qwen2_7b_api_url: Option<Arc<str>>,
pub zed_client_checksum_seed: Option<String>,
pub slack_panics_webhook: Option<String>,
pub auto_join_channel_id: Option<ChannelId>,
Expand Down
24 changes: 24 additions & 0 deletions crates/collab/src/rpc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4706,6 +4706,30 @@ async fn stream_complete_with_language_model(
})?;
}
}
// Zed-hosted models: proxy the request to the configured Qwen2-7B endpoint.
Some(proto::LanguageModelProvider::Zed) => {
// Both the API key and the endpoint URL must be present in the server config;
// a missing value fails the request with the context message below.
let api_key = config
.qwen2_7b_api_key
.as_ref()
.context("no Qwen2-7B API key configured on the server")?;
let api_url = config
.qwen2_7b_api_url
.as_ref()
.context("no Qwen2-7B URL configured on the server")?;
// The endpoint is driven through the OpenAI client, so the JSON in
// `request.request` is parsed into whatever request type
// `open_ai::stream_completion` takes.
let mut events = open_ai::stream_completion(
session.http_client.as_ref(),
&api_url,
api_key,
serde_json::from_str(&request.request)?,
// NOTE(review): the meaning of this last argument is not visible here,
// presumably an optional timeout. Confirm against `open_ai::stream_completion`.
None,
)
.await?;
// Relay every event to the caller re-encoded as a JSON string. The first
// failed event, serialization, or send ends the stream via `?`.
while let Some(event) = events.next().await {
let event = event?;
response.send(proto::StreamCompleteWithLanguageModelResponse {
event: serde_json::to_string(&event)?,
})?;
}
}
None => return Err(anyhow!("unknown provider"))?,
}

Expand Down
2 changes: 2 additions & 0 deletions crates/collab/src/tests/test_server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -672,6 +672,8 @@ impl TestServer {
stripe_api_key: None,
stripe_price_id: None,
supermaven_admin_api_key: None,
qwen2_7b_api_key: None,
qwen2_7b_api_url: None,
},
})
}
Expand Down
31 changes: 31 additions & 0 deletions crates/language_model/src/model/cloud_model.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,40 @@
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use strum::EnumIter;

// A language model reachable through the cloud provider, grouped by the
// provider that serves it.
//
// The serde attributes tag the serialized form with a `provider` field whose
// value is the lowercased variant name.
//
// NOTE(review): plain `//` comments are used on purpose. This type derives
// `JsonSchema`, and `///` doc comments would presumably surface as schema
// descriptions. Confirm before converting these to doc comments.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
#[serde(tag = "provider", rename_all = "lowercase")]
pub enum CloudModel {
// A model from the `anthropic` crate.
Anthropic(anthropic::Model),
// A model from the `open_ai` crate.
OpenAi(open_ai::Model),
// A model from the `google_ai` crate.
Google(google_ai::Model),
// A model listed in `ZedModel`.
Zed(ZedModel),
}

// Models exposed under the `Zed` variant of `CloudModel`.
//
// `EnumIter` is derived so that callers can enumerate every variant with
// `ZedModel::iter()`.
//
// NOTE(review): plain `//` comments are used on purpose. This type derives
// `JsonSchema`, and `///` doc comments would presumably surface as schema
// descriptions. Confirm before converting these to doc comments.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema, EnumIter)]
pub enum ZedModel {
// Serialized as "qwen2-7b-instruct", the same string `id()` returns.
#[serde(rename = "qwen2-7b-instruct")]
Qwen2_7bInstruct,
}

impl ZedModel {
pub fn id(&self) -> &str {
match self {
ZedModel::Qwen2_7bInstruct => "qwen2-7b-instruct",
}
}

pub fn display_name(&self) -> &str {
match self {
ZedModel::Qwen2_7bInstruct => "Qwen2 7B Instruct",
}
}

pub fn max_token_count(&self) -> usize {
match self {
ZedModel::Qwen2_7bInstruct => 8192,
}
}
}

impl Default for CloudModel {
Expand All @@ -21,6 +49,7 @@ impl CloudModel {
CloudModel::Anthropic(model) => model.id(),
CloudModel::OpenAi(model) => model.id(),
CloudModel::Google(model) => model.id(),
CloudModel::Zed(model) => model.id(),
}
}

Expand All @@ -29,6 +58,7 @@ impl CloudModel {
CloudModel::Anthropic(model) => model.display_name(),
CloudModel::OpenAi(model) => model.display_name(),
CloudModel::Google(model) => model.display_name(),
CloudModel::Zed(model) => model.display_name(),
}
}

Expand All @@ -37,6 +67,7 @@ impl CloudModel {
CloudModel::Anthropic(model) => model.max_token_count(),
CloudModel::OpenAi(model) => model.max_token_count(),
CloudModel::Google(model) => model.max_token_count(),
CloudModel::Zed(model) => model.max_token_count(),
}
}
}
29 changes: 28 additions & 1 deletion crates/language_model/src/provider/cloud.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use super::open_ai::count_open_ai_tokens;
use crate::{
settings::AllLanguageModelSettings, CloudModel, LanguageModel, LanguageModelId,
LanguageModelName, LanguageModelProviderId, LanguageModelProviderName,
LanguageModelProviderState, LanguageModelRequest, RateLimiter,
LanguageModelProviderState, LanguageModelRequest, RateLimiter, ZedModel,
};
use anyhow::{anyhow, Context as _, Result};
use client::{Client, UserStore};
Expand Down Expand Up @@ -146,6 +146,9 @@ impl LanguageModelProvider for CloudLanguageModelProvider {
models.insert(model.id().to_string(), CloudModel::Google(model));
}
}
for model in ZedModel::iter() {
models.insert(model.id().to_string(), CloudModel::Zed(model));
}

// Override with available models from settings
for model in &AllLanguageModelSettings::get_global(cx)
Expand Down Expand Up @@ -263,6 +266,9 @@ impl LanguageModel for CloudLanguageModel {
}
.boxed()
}
// Zed-hosted models are counted through the OpenAI token counter, using
// GPT-3.5 Turbo as the model.
// NOTE(review): presumably an approximation, since no Qwen2-specific counter
// is used here. Confirm the estimate is close enough for this model.
CloudModel::Zed(_) => {
count_open_ai_tokens(request, open_ai::Model::ThreePointFiveTurbo, cx)
}
}
}

Expand Down Expand Up @@ -323,6 +329,24 @@ impl LanguageModel for CloudLanguageModel {
});
async move { Ok(future.await?.boxed()) }.boxed()
}
// Zed-hosted models: build an OpenAI-shaped request and stream it through the
// client's `StreamCompleteWithLanguageModel` RPC with the `Zed` provider.
CloudModel::Zed(model) => {
let client = self.client.clone();
let mut request = request.into_open_ai(model.id().into());
// Completion length is capped at a fixed 4000 tokens for every Zed model.
request.max_tokens = Some(4000);
let future = self.request_limiter.stream(async move {
// The request travels as a JSON string inside the proto message.
let request = serde_json::to_string(&request)?;
let stream = client
.request_stream(proto::StreamCompleteWithLanguageModel {
provider: proto::LanguageModelProvider::Zed as i32,
request,
})
.await?;
// Each response carries one JSON-encoded event; decode it and let the
// OpenAI helper pull the text out of the event stream.
Ok(open_ai::extract_text_from_events(
stream.map(|item| Ok(serde_json::from_str(&item?.event)?)),
))
});
async move { Ok(future.await?.boxed()) }.boxed()
}
}
}

Expand Down Expand Up @@ -382,6 +406,9 @@ impl LanguageModel for CloudLanguageModel {
CloudModel::Google(_) => {
future::ready(Err(anyhow!("tool use not implemented for Google AI"))).boxed()
}
CloudModel::Zed(_) => {
future::ready(Err(anyhow!("tool use not implemented for Zed models"))).boxed()
}
}
}
}
Expand Down
1 change: 1 addition & 0 deletions crates/language_model/src/request.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ impl LanguageModelRequest {
stream: true,
stop: self.stop,
temperature: self.temperature,
max_tokens: None,
tools: Vec::new(),
tool_choice: None,
}
Expand Down
14 changes: 13 additions & 1 deletion crates/open_ai/src/open_ai.rs
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,8 @@ pub struct Request {
pub model: String,
pub messages: Vec<RequestMessage>,
pub stream: bool,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub max_tokens: Option<usize>,
pub stop: Vec<String>,
pub temperature: f32,
#[serde(default, skip_serializing_if = "Option::is_none")]
Expand Down Expand Up @@ -216,6 +218,13 @@ pub struct ChoiceDelta {
pub finish_reason: Option<String>,
}

/// One decoded line of a streamed completion response: either a regular event
/// or an error payload.
///
/// The enum is `untagged`, so serde picks the variant by shape: it tries `Ok`
/// first and falls back to `Err` for an object carrying a string `error` field.
#[derive(Serialize, Deserialize, Debug)]
#[serde(untagged)]
pub enum ResponseStreamResult {
/// A normal stream event.
Ok(ResponseStreamEvent),
/// An error payload; `stream_completion` converts `error` into an `Err` item.
Err { error: String },
}

#[derive(Serialize, Deserialize, Debug)]
pub struct ResponseStreamEvent {
pub created: u32,
Expand Down Expand Up @@ -256,7 +265,10 @@ pub async fn stream_completion(
None
} else {
match serde_json::from_str(line) {
Ok(response) => Some(Ok(response)),
Ok(ResponseStreamResult::Ok(response)) => Some(Ok(response)),
Ok(ResponseStreamResult::Err { error }) => {
Some(Err(anyhow!(error)))
}
Err(error) => Some(Err(anyhow!(error))),
}
}
Expand Down
1 change: 1 addition & 0 deletions crates/proto/proto/zed.proto
Original file line number Diff line number Diff line change
Expand Up @@ -2099,6 +2099,7 @@ enum LanguageModelProvider {
Anthropic = 0;
OpenAI = 1;
Google = 2;
Zed = 3;
}

message GetCachedEmbeddings {
Expand Down

0 comments on commit 21816d1

Please sign in to comment.