From 5944caaa9018ef1cc45f68df4fc910afeb1c2e5c Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Tue, 14 May 2024 15:57:52 +0200 Subject: [PATCH] Add support for interacting with Claude in the assistant panel (#11798) Release Notes: - Added support for interacting with Claude in the assistant panel. You can enable it by adding the following to your `settings.json`: ```json "assistant": { "version": "1", "provider": { "name": "anthropic" } } ``` --- Cargo.lock | 3 + crates/anthropic/Cargo.toml | 6 + crates/anthropic/src/anthropic.rs | 34 +- crates/assistant/Cargo.toml | 1 + crates/assistant/src/assistant.rs | 7 +- crates/assistant/src/assistant_panel.rs | 5 + crates/assistant/src/assistant_settings.rs | 16 +- crates/assistant/src/completion_provider.rs | 61 +++- .../src/completion_provider/anthropic.rs | 317 ++++++++++++++++++ .../src/completion_provider/open_ai.rs | 10 +- .../assistant/src/completion_provider/zed.rs | 2 +- crates/collab/src/rpc.rs | 5 +- 12 files changed, 446 insertions(+), 21 deletions(-) create mode 100644 crates/assistant/src/completion_provider/anthropic.rs diff --git a/Cargo.lock b/Cargo.lock index 44de9e533d611..678730a1c9ce6 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -225,6 +225,8 @@ dependencies = [ "anyhow", "futures 0.3.28", "http 0.1.0", + "isahc", + "schemars", "serde", "serde_json", "tokio", @@ -332,6 +334,7 @@ dependencies = [ name = "assistant" version = "0.1.0" dependencies = [ + "anthropic", "anyhow", "chrono", "client", diff --git a/crates/anthropic/Cargo.toml b/crates/anthropic/Cargo.toml index 1cd9d1b953327..484a9b3e10f0e 100644 --- a/crates/anthropic/Cargo.toml +++ b/crates/anthropic/Cargo.toml @@ -5,6 +5,10 @@ edition = "2021" publish = false license = "AGPL-3.0-or-later" +[features] +default = [] +schemars = ["dep:schemars"] + [lints] workspace = true @@ -15,6 +19,8 @@ path = "src/anthropic.rs" anyhow.workspace = true futures.workspace = true http.workspace = true +isahc.workspace = true +schemars = { workspace = true, optional = true } serde.workspace = true serde_json.workspace = true diff --git a/crates/anthropic/src/anthropic.rs b/crates/anthropic/src/anthropic.rs index dd5d523fb40f1..d642166af84b0 100644 --- a/crates/anthropic/src/anthropic.rs +++ b/crates/anthropic/src/anthropic.rs @@ -1,17 +1,21 @@ use anyhow::{anyhow, Result}; use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, StreamExt}; use http::{AsyncBody, HttpClient, Method, Request as HttpRequest}; +use isahc::config::Configurable; use serde::{Deserialize, Serialize}; -use std::{convert::TryFrom, sync::Arc}; +use std::{convert::TryFrom, time::Duration}; +pub const ANTHROPIC_API_URL: &'static str = "https://api.anthropic.com"; + +#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] #[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)] pub enum Model { #[default] - #[serde(rename = "claude-3-opus-20240229")] + #[serde(rename = "claude-3-opus", alias = "claude-3-opus-20240229")] Claude3Opus, - #[serde(rename = "claude-3-sonnet-20240229")] + #[serde(rename = "claude-3-sonnet", alias = "claude-3-sonnet-20240229")] Claude3Sonnet, - #[serde(rename = "claude-3-haiku-20240307")] + #[serde(rename = "claude-3-haiku", alias = "claude-3-haiku-20240307")] Claude3Haiku, } @@ -28,6 +32,14 @@ impl Model { } } + pub fn id(&self) -> &'static str { + match self { + Model::Claude3Opus => "claude-3-opus-20240229", + Model::Claude3Sonnet => "claude-3-sonnet-20240229", + Model::Claude3Haiku => "claude-3-opus-20240307", + } + } + pub fn display_name(&self) -> &'static str { match self { Self::Claude3Opus => "Claude 3 Opus", @@ -141,20 +153,24 @@ pub enum TextDelta { } pub async fn stream_completion( - client: Arc, + client: &dyn HttpClient, api_url: &str, api_key: &str, request: Request, + low_speed_timeout: Option, ) -> Result>> { let uri = format!("{api_url}/v1/messages"); - let request = HttpRequest::builder() + let mut request_builder = HttpRequest::builder() .method(Method::POST) .uri(uri) .header("Anthropic-Version", "2023-06-01") - .header("Anthropic-Beta", "messages-2023-12-15") + .header("Anthropic-Beta", "tools-2024-04-04") .header("X-Api-Key", api_key) - .header("Content-Type", "application/json") - .body(AsyncBody::from(serde_json::to_string(&request)?))?; + .header("Content-Type", "application/json"); + if let Some(low_speed_timeout) = low_speed_timeout { + request_builder = request_builder.low_speed_timeout(100, low_speed_timeout); + } + let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?; let mut response = client.send(request).await?; if response.status().is_success() { let reader = BufReader::new(response.into_body()); diff --git a/crates/assistant/Cargo.toml b/crates/assistant/Cargo.toml index 9249ece8dbf38..eb9caef25ba52 100644 --- a/crates/assistant/Cargo.toml +++ b/crates/assistant/Cargo.toml @@ -11,6 +11,7 @@ doctest = false [dependencies] anyhow.workspace = true +anthropic = { workspace = true, features = ["schemars"] } chrono.workspace = true client.workspace = true collections.workspace = true diff --git a/crates/assistant/src/assistant.rs b/crates/assistant/src/assistant.rs index ae436df59ffee..949cc1741d6bb 100644 --- a/crates/assistant/src/assistant.rs +++ b/crates/assistant/src/assistant.rs @@ -7,7 +7,7 @@ mod saved_conversation; mod streaming_diff; pub use assistant_panel::AssistantPanel; -use assistant_settings::{AssistantSettings, OpenAiModel, ZedDotDevModel}; +use assistant_settings::{AnthropicModel, AssistantSettings, OpenAiModel, ZedDotDevModel}; use client::{proto, Client}; use command_palette_hooks::CommandPaletteFilter; pub(crate) use completion_provider::*; @@ -72,6 +72,7 @@ impl Display for Role { pub enum LanguageModel { ZedDotDev(ZedDotDevModel), OpenAi(OpenAiModel), + Anthropic(AnthropicModel), } impl Default for LanguageModel { @@ -84,6 +85,7 @@ impl LanguageModel { pub fn telemetry_id(&self) -> String { match self { LanguageModel::OpenAi(model) => format!("openai/{}", model.id()), + LanguageModel::Anthropic(model) => format!("anthropic/{}", model.id()), LanguageModel::ZedDotDev(model) => format!("zed.dev/{}", model.id()), } } @@ -91,6 +93,7 @@ impl LanguageModel { pub fn display_name(&self) -> String { match self { LanguageModel::OpenAi(model) => model.display_name().into(), + LanguageModel::Anthropic(model) => model.display_name().into(), LanguageModel::ZedDotDev(model) => model.display_name().into(), } } @@ -98,6 +101,7 @@ impl LanguageModel { pub fn max_token_count(&self) -> usize { match self { LanguageModel::OpenAi(model) => model.max_token_count(), + LanguageModel::Anthropic(model) => model.max_token_count(), LanguageModel::ZedDotDev(model) => model.max_token_count(), } } @@ -105,6 +109,7 @@ impl LanguageModel { pub fn id(&self) -> &str { match self { LanguageModel::OpenAi(model) => model.id(), + LanguageModel::Anthropic(model) => model.id(), LanguageModel::ZedDotDev(model) => model.id(), } } diff --git a/crates/assistant/src/assistant_panel.rs b/crates/assistant/src/assistant_panel.rs index c47cd022d86c3..aa3ac66db08b4 100644 --- a/crates/assistant/src/assistant_panel.rs +++ b/crates/assistant/src/assistant_panel.rs @@ -800,6 +800,11 @@ impl AssistantPanel { open_ai::Model::FourTurbo => open_ai::Model::FourOmni, open_ai::Model::FourOmni => open_ai::Model::ThreePointFiveTurbo, }), + LanguageModel::Anthropic(model) => LanguageModel::Anthropic(match &model { + anthropic::Model::Claude3Opus => anthropic::Model::Claude3Sonnet, + anthropic::Model::Claude3Sonnet => anthropic::Model::Claude3Haiku, + anthropic::Model::Claude3Haiku => anthropic::Model::Claude3Opus, + }), LanguageModel::ZedDotDev(model) => LanguageModel::ZedDotDev(match &model { ZedDotDevModel::Gpt3Point5Turbo => ZedDotDevModel::Gpt4, ZedDotDevModel::Gpt4 => ZedDotDevModel::Gpt4Turbo, diff --git a/crates/assistant/src/assistant_settings.rs b/crates/assistant/src/assistant_settings.rs index 4a5dd39952431..7341e9ecbec76 100644 --- a/crates/assistant/src/assistant_settings.rs +++ b/crates/assistant/src/assistant_settings.rs @@ -1,5 +1,6 @@ use std::fmt; +pub use anthropic::Model as AnthropicModel; use gpui::Pixels; pub use open_ai::Model as OpenAiModel; use schemars::{ @@ -161,6 +162,15 @@ pub enum AssistantProvider { #[serde(default)] low_speed_timeout_in_seconds: Option, }, + #[serde(rename = "anthropic")] + Anthropic { + #[serde(default)] + default_model: AnthropicModel, + #[serde(default = "anthropic_api_url")] + api_url: String, + #[serde(default)] + low_speed_timeout_in_seconds: Option, + }, } impl Default for AssistantProvider { @@ -172,7 +182,11 @@ impl Default for AssistantProvider { } fn open_ai_url() -> String { - "https://api.openai.com/v1".into() + open_ai::OPEN_AI_API_URL.to_string() +} + +fn anthropic_api_url() -> String { + anthropic::ANTHROPIC_API_URL.to_string() } #[derive(Default, Debug, Deserialize, Serialize)] diff --git a/crates/assistant/src/completion_provider.rs b/crates/assistant/src/completion_provider.rs index 534709358e263..a2c60d708d2a4 100644 --- a/crates/assistant/src/completion_provider.rs +++ b/crates/assistant/src/completion_provider.rs @@ -1,8 +1,10 @@ +mod anthropic; #[cfg(test)] mod fake; mod open_ai; mod zed; +pub use anthropic::*; #[cfg(test)] pub use fake::*; pub use open_ai::*; @@ -42,6 +44,17 @@ pub fn init(client: Arc, cx: &mut AppContext) { low_speed_timeout_in_seconds.map(Duration::from_secs), settings_version, )), + AssistantProvider::Anthropic { + default_model, + api_url, + low_speed_timeout_in_seconds, + } => CompletionProvider::Anthropic(AnthropicCompletionProvider::new( + default_model.clone(), + api_url.clone(), + client.http_client(), + low_speed_timeout_in_seconds.map(Duration::from_secs), + settings_version, + )), }; cx.set_global(provider); @@ -64,13 +77,28 @@ pub fn init(client: Arc, cx: &mut AppContext) { settings_version, ); } + ( + CompletionProvider::Anthropic(provider), + AssistantProvider::Anthropic { + default_model, + api_url, + low_speed_timeout_in_seconds, + }, + ) => { + provider.update( + default_model.clone(), + api_url.clone(), + low_speed_timeout_in_seconds.map(Duration::from_secs), + settings_version, + ); + } ( CompletionProvider::ZedDotDev(provider), AssistantProvider::ZedDotDev { default_model }, ) => { provider.update(default_model.clone(), settings_version); } - (CompletionProvider::OpenAi(_), AssistantProvider::ZedDotDev { default_model }) => { + (_, AssistantProvider::ZedDotDev { default_model }) => { *provider = CompletionProvider::ZedDotDev(ZedDotDevCompletionProvider::new( default_model.clone(), client.clone(), @@ -79,7 +107,7 @@ pub fn init(client: Arc, cx: &mut AppContext) { )); } ( - CompletionProvider::ZedDotDev(_), + _, AssistantProvider::OpenAi { default_model, api_url, @@ -94,8 +122,22 @@ pub fn init(client: Arc, cx: &mut AppContext) { settings_version, )); } - #[cfg(test)] - (CompletionProvider::Fake(_), _) => unimplemented!(), + ( + _, + AssistantProvider::Anthropic { + default_model, + api_url, + low_speed_timeout_in_seconds, + }, + ) => { + *provider = CompletionProvider::Anthropic(AnthropicCompletionProvider::new( + default_model.clone(), + api_url.clone(), + client.http_client(), + low_speed_timeout_in_seconds.map(Duration::from_secs), + settings_version, + )); + } } }) }) @@ -104,6 +146,7 @@ pub fn init(client: Arc, cx: &mut AppContext) { pub enum CompletionProvider { OpenAi(OpenAiCompletionProvider), + Anthropic(AnthropicCompletionProvider), ZedDotDev(ZedDotDevCompletionProvider), #[cfg(test)] Fake(FakeCompletionProvider), @@ -119,6 +162,7 @@ impl CompletionProvider { pub fn settings_version(&self) -> usize { match self { CompletionProvider::OpenAi(provider) => provider.settings_version(), + CompletionProvider::Anthropic(provider) => provider.settings_version(), CompletionProvider::ZedDotDev(provider) => provider.settings_version(), #[cfg(test)] CompletionProvider::Fake(_) => unimplemented!(), @@ -128,6 +172,7 @@ impl CompletionProvider { pub fn is_authenticated(&self) -> bool { match self { CompletionProvider::OpenAi(provider) => provider.is_authenticated(), + CompletionProvider::Anthropic(provider) => provider.is_authenticated(), CompletionProvider::ZedDotDev(provider) => provider.is_authenticated(), #[cfg(test)] CompletionProvider::Fake(_) => true, @@ -137,6 +182,7 @@ impl CompletionProvider { pub fn authenticate(&self, cx: &AppContext) -> Task> { match self { CompletionProvider::OpenAi(provider) => provider.authenticate(cx), + CompletionProvider::Anthropic(provider) => provider.authenticate(cx), CompletionProvider::ZedDotDev(provider) => provider.authenticate(cx), #[cfg(test)] CompletionProvider::Fake(_) => Task::ready(Ok(())), @@ -146,6 +192,7 @@ impl CompletionProvider { pub fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView { match self { CompletionProvider::OpenAi(provider) => provider.authentication_prompt(cx), + CompletionProvider::Anthropic(provider) => provider.authentication_prompt(cx), CompletionProvider::ZedDotDev(provider) => provider.authentication_prompt(cx), #[cfg(test)] CompletionProvider::Fake(_) => unimplemented!(), @@ -155,6 +202,7 @@ impl CompletionProvider { pub fn reset_credentials(&self, cx: &AppContext) -> Task> { match self { CompletionProvider::OpenAi(provider) => provider.reset_credentials(cx), + CompletionProvider::Anthropic(provider) => provider.reset_credentials(cx), CompletionProvider::ZedDotDev(_) => Task::ready(Ok(())), #[cfg(test)] CompletionProvider::Fake(_) => Task::ready(Ok(())), @@ -164,6 +212,9 @@ impl CompletionProvider { pub fn default_model(&self) -> LanguageModel { match self { CompletionProvider::OpenAi(provider) => LanguageModel::OpenAi(provider.default_model()), + CompletionProvider::Anthropic(provider) => { + LanguageModel::Anthropic(provider.default_model()) + } CompletionProvider::ZedDotDev(provider) => { LanguageModel::ZedDotDev(provider.default_model()) } @@ -179,6 +230,7 @@ impl CompletionProvider { ) -> BoxFuture<'static, Result> { match self { CompletionProvider::OpenAi(provider) => provider.count_tokens(request, cx), + CompletionProvider::Anthropic(provider) => provider.count_tokens(request, cx), CompletionProvider::ZedDotDev(provider) => provider.count_tokens(request, cx), #[cfg(test)] CompletionProvider::Fake(_) => unimplemented!(), @@ -191,6 +243,7 @@ impl CompletionProvider { ) -> BoxFuture<'static, Result>>> { match self { CompletionProvider::OpenAi(provider) => provider.complete(request), + CompletionProvider::Anthropic(provider) => provider.complete(request), CompletionProvider::ZedDotDev(provider) => provider.complete(request), #[cfg(test)] CompletionProvider::Fake(provider) => provider.complete(), diff --git a/crates/assistant/src/completion_provider/anthropic.rs b/crates/assistant/src/completion_provider/anthropic.rs new file mode 100644 index 0000000000000..d3a05c8f518d8 --- /dev/null +++ b/crates/assistant/src/completion_provider/anthropic.rs @@ -0,0 +1,317 @@ +use crate::count_open_ai_tokens; +use crate::{ + assistant_settings::AnthropicModel, CompletionProvider, LanguageModel, LanguageModelRequest, + Role, +}; +use anthropic::{stream_completion, Request, RequestMessage, Role as AnthropicRole}; +use anyhow::{anyhow, Result}; +use editor::{Editor, EditorElement, EditorStyle}; +use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt}; +use gpui::{AnyView, AppContext, FontStyle, FontWeight, Task, TextStyle, View, WhiteSpace}; +use http::HttpClient; +use settings::Settings; +use std::time::Duration; +use std::{env, sync::Arc}; +use theme::ThemeSettings; +use ui::prelude::*; +use util::ResultExt; + +pub struct AnthropicCompletionProvider { + api_key: Option, + api_url: String, + default_model: AnthropicModel, + http_client: Arc, + low_speed_timeout: Option, + settings_version: usize, +} + +impl AnthropicCompletionProvider { + pub fn new( + default_model: AnthropicModel, + api_url: String, + http_client: Arc, + low_speed_timeout: Option, + settings_version: usize, + ) -> Self { + Self { + api_key: None, + api_url, + default_model, + http_client, + low_speed_timeout, + settings_version, + } + } + + pub fn update( + &mut self, + default_model: AnthropicModel, + api_url: String, + low_speed_timeout: Option, + settings_version: usize, + ) { + self.default_model = default_model; + self.api_url = api_url; + self.low_speed_timeout = low_speed_timeout; + self.settings_version = settings_version; + } + + pub fn settings_version(&self) -> usize { + self.settings_version + } + + pub fn is_authenticated(&self) -> bool { + self.api_key.is_some() + } + + pub fn authenticate(&self, cx: &AppContext) -> Task> { + if self.is_authenticated() { + Task::ready(Ok(())) + } else { + let api_url = self.api_url.clone(); + cx.spawn(|mut cx| async move { + let api_key = if let Ok(api_key) = env::var("ANTHROPIC_API_KEY") { + api_key + } else { + let (_, api_key) = cx + .update(|cx| cx.read_credentials(&api_url))? + .await? + .ok_or_else(|| anyhow!("credentials not found"))?; + String::from_utf8(api_key)? + }; + cx.update_global::(|provider, _cx| { + if let CompletionProvider::Anthropic(provider) = provider { + provider.api_key = Some(api_key); + } + }) + }) + } + } + + pub fn reset_credentials(&self, cx: &AppContext) -> Task> { + let delete_credentials = cx.delete_credentials(&self.api_url); + cx.spawn(|mut cx| async move { + delete_credentials.await.log_err(); + cx.update_global::(|provider, _cx| { + if let CompletionProvider::Anthropic(provider) = provider { + provider.api_key = None; + } + }) + }) + } + + pub fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView { + cx.new_view(|cx| AuthenticationPrompt::new(self.api_url.clone(), cx)) + .into() + } + + pub fn default_model(&self) -> AnthropicModel { + self.default_model.clone() + } + + pub fn count_tokens( + &self, + request: LanguageModelRequest, + cx: &AppContext, + ) -> BoxFuture<'static, Result> { + count_open_ai_tokens(request, cx.background_executor()) + } + + pub fn complete( + &self, + request: LanguageModelRequest, + ) -> BoxFuture<'static, Result>>> { + let request = self.to_anthropic_request(request); + + let http_client = self.http_client.clone(); + let api_key = self.api_key.clone(); + let api_url = self.api_url.clone(); + let low_speed_timeout = self.low_speed_timeout; + async move { + let api_key = api_key.ok_or_else(|| anyhow!("missing api key"))?; + let request = stream_completion( + http_client.as_ref(), + &api_url, + &api_key, + request, + low_speed_timeout, + ); + let response = request.await?; + let stream = response + .filter_map(|response| async move { + match response { + Ok(response) => match response { + anthropic::ResponseEvent::ContentBlockStart { + content_block, .. + } => match content_block { + anthropic::ContentBlock::Text { text } => Some(Ok(text)), + }, + anthropic::ResponseEvent::ContentBlockDelta { delta, .. } => { + match delta { + anthropic::TextDelta::TextDelta { text } => Some(Ok(text)), + } + } + _ => None, + }, + Err(error) => Some(Err(error)), + } + }) + .boxed(); + Ok(stream) + } + .boxed() + } + + fn to_anthropic_request(&self, request: LanguageModelRequest) -> Request { + let model = match request.model { + LanguageModel::Anthropic(model) => model, + _ => self.default_model(), + }; + + let mut system_message = String::new(); + let messages = request + .messages + .into_iter() + .filter_map(|message| { + match message.role { + Role::User => Some(RequestMessage { + role: AnthropicRole::User, + content: message.content, + }), + Role::Assistant => Some(RequestMessage { + role: AnthropicRole::Assistant, + content: message.content, + }), + // Anthropic's API breaks system instructions out as a separate field rather + // than having a system message role. + Role::System => { + if !system_message.is_empty() { + system_message.push_str("\n\n"); + } + system_message.push_str(&message.content); + + None + } + } + }) + .collect(); + + Request { + model, + messages, + stream: true, + system: system_message, + max_tokens: 4092, + } + } +} + +struct AuthenticationPrompt { + api_key: View, + api_url: String, +} + +impl AuthenticationPrompt { + fn new(api_url: String, cx: &mut WindowContext) -> Self { + Self { + api_key: cx.new_view(|cx| { + let mut editor = Editor::single_line(cx); + editor.set_placeholder_text( + "sk-000000000000000000000000000000000000000000000000", + cx, + ); + editor + }), + api_url, + } + } + + fn save_api_key(&mut self, _: &menu::Confirm, cx: &mut ViewContext) { + let api_key = self.api_key.read(cx).text(cx); + if api_key.is_empty() { + return; + } + + let write_credentials = cx.write_credentials(&self.api_url, "Bearer", api_key.as_bytes()); + cx.spawn(|_, mut cx| async move { + write_credentials.await?; + cx.update_global::(|provider, _cx| { + if let CompletionProvider::Anthropic(provider) = provider { + provider.api_key = Some(api_key); + } + }) + }) + .detach_and_log_err(cx); + } + + fn render_api_key_editor(&self, cx: &mut ViewContext) -> impl IntoElement { + let settings = ThemeSettings::get_global(cx); + let text_style = TextStyle { + color: cx.theme().colors().text, + font_family: settings.ui_font.family.clone(), + font_features: settings.ui_font.features.clone(), + font_size: rems(0.875).into(), + font_weight: FontWeight::NORMAL, + font_style: FontStyle::Normal, + line_height: relative(1.3), + background_color: None, + underline: None, + strikethrough: None, + white_space: WhiteSpace::Normal, + }; + EditorElement::new( + &self.api_key, + EditorStyle { + background: cx.theme().colors().editor_background, + local_player: cx.theme().players().local(), + text: text_style, + ..Default::default() + }, + ) + } +} + +impl Render for AuthenticationPrompt { + fn render(&mut self, cx: &mut ViewContext) -> impl IntoElement { + const INSTRUCTIONS: [&str; 4] = [ + "To use the assistant panel or inline assistant, you need to add your Anthropic API key.", + "You can create an API key at: https://console.anthropic.com/settings/keys", + "", + "Paste your Anthropic API key below and hit enter to use the assistant:", + ]; + + v_flex() + .p_4() + .size_full() + .on_action(cx.listener(Self::save_api_key)) + .children( + INSTRUCTIONS.map(|instruction| Label::new(instruction).size(LabelSize::Small)), + ) + .child( + h_flex() + .w_full() + .my_2() + .px_2() + .py_1() + .bg(cx.theme().colors().editor_background) + .rounded_md() + .child(self.render_api_key_editor(cx)), + ) + .child( + Label::new( + "You can also assign the ANTHROPIC_API_KEY environment variable and restart Zed.", + ) + .size(LabelSize::Small), + ) + .child( + h_flex() + .gap_2() + .child(Label::new("Click on").size(LabelSize::Small)) + .child(Icon::new(IconName::Ai).size(IconSize::XSmall)) + .child( + Label::new("in the status bar to close this panel.").size(LabelSize::Small), + ), + ) + .into_any() + } +} diff --git a/crates/assistant/src/completion_provider/open_ai.rs b/crates/assistant/src/completion_provider/open_ai.rs index 347c8318e505b..b456f3e08acbc 100644 --- a/crates/assistant/src/completion_provider/open_ai.rs +++ b/crates/assistant/src/completion_provider/open_ai.rs @@ -151,8 +151,8 @@ impl OpenAiCompletionProvider { fn to_open_ai_request(&self, request: LanguageModelRequest) -> Request { let model = match request.model { - LanguageModel::ZedDotDev(_) => self.default_model(), LanguageModel::OpenAi(model) => model, + _ => self.default_model(), }; Request { @@ -205,8 +205,12 @@ pub fn count_open_ai_tokens( match request.model { LanguageModel::OpenAi(OpenAiModel::FourOmni) - | LanguageModel::ZedDotDev(ZedDotDevModel::Gpt4Omni) => { - // Tiktoken doesn't yet support gpt-4o, so we manually use the + | LanguageModel::ZedDotDev(ZedDotDevModel::Gpt4Omni) + | LanguageModel::Anthropic(_) + | LanguageModel::ZedDotDev(ZedDotDevModel::Claude3Opus) + | LanguageModel::ZedDotDev(ZedDotDevModel::Claude3Sonnet) + | LanguageModel::ZedDotDev(ZedDotDevModel::Claude3Haiku) => { + // Tiktoken doesn't yet support these models, so we manually use the // same tokenizer as GPT-4. tiktoken_rs::num_tokens_from_messages("gpt-4", &messages) } diff --git a/crates/assistant/src/completion_provider/zed.rs b/crates/assistant/src/completion_provider/zed.rs index fc34c7fa6a51f..8fa149807204d 100644 --- a/crates/assistant/src/completion_provider/zed.rs +++ b/crates/assistant/src/completion_provider/zed.rs @@ -78,7 +78,6 @@ impl ZedDotDevCompletionProvider { cx: &AppContext, ) -> BoxFuture<'static, Result> { match request.model { - LanguageModel::OpenAi(_) => future::ready(Err(anyhow!("invalid model"))).boxed(), LanguageModel::ZedDotDev(ZedDotDevModel::Gpt4) | LanguageModel::ZedDotDev(ZedDotDevModel::Gpt4Turbo) | LanguageModel::ZedDotDev(ZedDotDevModel::Gpt4Omni) @@ -108,6 +107,7 @@ impl ZedDotDevCompletionProvider { } .boxed() } + _ => future::ready(Err(anyhow!("invalid model"))).boxed(), } } diff --git a/crates/collab/src/rpc.rs b/crates/collab/src/rpc.rs index 78f39403b0643..29f9644028cb2 100644 --- a/crates/collab/src/rpc.rs +++ b/crates/collab/src/rpc.rs @@ -4489,8 +4489,8 @@ async fn complete_with_anthropic( .collect(); let mut stream = anthropic::stream_completion( - session.http_client.clone(), - "https://api.anthropic.com", + session.http_client.as_ref(), + anthropic::ANTHROPIC_API_URL, &api_key, anthropic::Request { model, @@ -4499,6 +4499,7 @@ async fn complete_with_anthropic( system: system_message, max_tokens: 4092, }, + None, ) .await?;