Skip to content

Commit

Permalink
Add support for interacting with Claude in the assistant panel (#11798)
Browse files Browse the repository at this point in the history
Release Notes:

- Added support for interacting with Claude in the assistant panel. You
can enable it by adding the following to your `settings.json`:

    ```json
    "assistant": {
        "version": "1",
        "provider": {
            "name": "anthropic"
        }
    }
    ```
  • Loading branch information
as-cii authored May 14, 2024
1 parent 019d988 commit 5944caa
Show file tree
Hide file tree
Showing 12 changed files with 446 additions and 21 deletions.
3 changes: 3 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 6 additions & 0 deletions crates/anthropic/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,10 @@ edition = "2021"
publish = false
license = "AGPL-3.0-or-later"

[features]
default = []
schemars = ["dep:schemars"]

[lints]
workspace = true

Expand All @@ -15,6 +19,8 @@ path = "src/anthropic.rs"
anyhow.workspace = true
futures.workspace = true
http.workspace = true
isahc.workspace = true
schemars = { workspace = true, optional = true }
serde.workspace = true
serde_json.workspace = true

Expand Down
34 changes: 25 additions & 9 deletions crates/anthropic/src/anthropic.rs
Original file line number Diff line number Diff line change
@@ -1,17 +1,21 @@
use anyhow::{anyhow, Result};
use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, StreamExt};
use http::{AsyncBody, HttpClient, Method, Request as HttpRequest};
use isahc::config::Configurable;
use serde::{Deserialize, Serialize};
use std::{convert::TryFrom, sync::Arc};
use std::{convert::TryFrom, time::Duration};

pub const ANTHROPIC_API_URL: &'static str = "https://api.anthropic.com";

#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
pub enum Model {
#[default]
#[serde(rename = "claude-3-opus-20240229")]
#[serde(rename = "claude-3-opus", alias = "claude-3-opus-20240229")]
Claude3Opus,
#[serde(rename = "claude-3-sonnet-20240229")]
#[serde(rename = "claude-3-sonnet", alias = "claude-3-sonnet-20240229")]
Claude3Sonnet,
#[serde(rename = "claude-3-haiku-20240307")]
#[serde(rename = "claude-3-haiku", alias = "claude-3-haiku-20240307")]
Claude3Haiku,
}

Expand All @@ -28,6 +32,14 @@ impl Model {
}
}

pub fn id(&self) -> &'static str {
match self {
Model::Claude3Opus => "claude-3-opus-20240229",
Model::Claude3Sonnet => "claude-3-sonnet-20240229",
Model::Claude3Haiku => "claude-3-opus-20240307",
}
}

pub fn display_name(&self) -> &'static str {
match self {
Self::Claude3Opus => "Claude 3 Opus",
Expand Down Expand Up @@ -141,20 +153,24 @@ pub enum TextDelta {
}

pub async fn stream_completion(
client: Arc<dyn HttpClient>,
client: &dyn HttpClient,
api_url: &str,
api_key: &str,
request: Request,
low_speed_timeout: Option<Duration>,
) -> Result<BoxStream<'static, Result<ResponseEvent>>> {
let uri = format!("{api_url}/v1/messages");
let request = HttpRequest::builder()
let mut request_builder = HttpRequest::builder()
.method(Method::POST)
.uri(uri)
.header("Anthropic-Version", "2023-06-01")
.header("Anthropic-Beta", "messages-2023-12-15")
.header("Anthropic-Beta", "tools-2024-04-04")
.header("X-Api-Key", api_key)
.header("Content-Type", "application/json")
.body(AsyncBody::from(serde_json::to_string(&request)?))?;
.header("Content-Type", "application/json");
if let Some(low_speed_timeout) = low_speed_timeout {
request_builder = request_builder.low_speed_timeout(100, low_speed_timeout);
}
let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
let mut response = client.send(request).await?;
if response.status().is_success() {
let reader = BufReader::new(response.into_body());
Expand Down
1 change: 1 addition & 0 deletions crates/assistant/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ doctest = false

[dependencies]
anyhow.workspace = true
anthropic = { workspace = true, features = ["schemars"] }
chrono.workspace = true
client.workspace = true
collections.workspace = true
Expand Down
7 changes: 6 additions & 1 deletion crates/assistant/src/assistant.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ mod saved_conversation;
mod streaming_diff;

pub use assistant_panel::AssistantPanel;
use assistant_settings::{AssistantSettings, OpenAiModel, ZedDotDevModel};
use assistant_settings::{AnthropicModel, AssistantSettings, OpenAiModel, ZedDotDevModel};
use client::{proto, Client};
use command_palette_hooks::CommandPaletteFilter;
pub(crate) use completion_provider::*;
Expand Down Expand Up @@ -72,6 +72,7 @@ impl Display for Role {
pub enum LanguageModel {
ZedDotDev(ZedDotDevModel),
OpenAi(OpenAiModel),
Anthropic(AnthropicModel),
}

impl Default for LanguageModel {
Expand All @@ -84,27 +85,31 @@ impl LanguageModel {
pub fn telemetry_id(&self) -> String {
match self {
LanguageModel::OpenAi(model) => format!("openai/{}", model.id()),
LanguageModel::Anthropic(model) => format!("anthropic/{}", model.id()),
LanguageModel::ZedDotDev(model) => format!("zed.dev/{}", model.id()),
}
}

pub fn display_name(&self) -> String {
match self {
LanguageModel::OpenAi(model) => model.display_name().into(),
LanguageModel::Anthropic(model) => model.display_name().into(),
LanguageModel::ZedDotDev(model) => model.display_name().into(),
}
}

pub fn max_token_count(&self) -> usize {
match self {
LanguageModel::OpenAi(model) => model.max_token_count(),
LanguageModel::Anthropic(model) => model.max_token_count(),
LanguageModel::ZedDotDev(model) => model.max_token_count(),
}
}

pub fn id(&self) -> &str {
match self {
LanguageModel::OpenAi(model) => model.id(),
LanguageModel::Anthropic(model) => model.id(),
LanguageModel::ZedDotDev(model) => model.id(),
}
}
Expand Down
5 changes: 5 additions & 0 deletions crates/assistant/src/assistant_panel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -800,6 +800,11 @@ impl AssistantPanel {
open_ai::Model::FourTurbo => open_ai::Model::FourOmni,
open_ai::Model::FourOmni => open_ai::Model::ThreePointFiveTurbo,
}),
LanguageModel::Anthropic(model) => LanguageModel::Anthropic(match &model {
anthropic::Model::Claude3Opus => anthropic::Model::Claude3Sonnet,
anthropic::Model::Claude3Sonnet => anthropic::Model::Claude3Haiku,
anthropic::Model::Claude3Haiku => anthropic::Model::Claude3Opus,
}),
LanguageModel::ZedDotDev(model) => LanguageModel::ZedDotDev(match &model {
ZedDotDevModel::Gpt3Point5Turbo => ZedDotDevModel::Gpt4,
ZedDotDevModel::Gpt4 => ZedDotDevModel::Gpt4Turbo,
Expand Down
16 changes: 15 additions & 1 deletion crates/assistant/src/assistant_settings.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use std::fmt;

pub use anthropic::Model as AnthropicModel;
use gpui::Pixels;
pub use open_ai::Model as OpenAiModel;
use schemars::{
Expand Down Expand Up @@ -161,6 +162,15 @@ pub enum AssistantProvider {
#[serde(default)]
low_speed_timeout_in_seconds: Option<u64>,
},
#[serde(rename = "anthropic")]
Anthropic {
#[serde(default)]
default_model: AnthropicModel,
#[serde(default = "anthropic_api_url")]
api_url: String,
#[serde(default)]
low_speed_timeout_in_seconds: Option<u64>,
},
}

impl Default for AssistantProvider {
Expand All @@ -172,7 +182,11 @@ impl Default for AssistantProvider {
}

fn open_ai_url() -> String {
"https://api.openai.com/v1".into()
open_ai::OPEN_AI_API_URL.to_string()
}

fn anthropic_api_url() -> String {
anthropic::ANTHROPIC_API_URL.to_string()
}

#[derive(Default, Debug, Deserialize, Serialize)]
Expand Down
61 changes: 57 additions & 4 deletions crates/assistant/src/completion_provider.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
mod anthropic;
#[cfg(test)]
mod fake;
mod open_ai;
mod zed;

pub use anthropic::*;
#[cfg(test)]
pub use fake::*;
pub use open_ai::*;
Expand Down Expand Up @@ -42,6 +44,17 @@ pub fn init(client: Arc<Client>, cx: &mut AppContext) {
low_speed_timeout_in_seconds.map(Duration::from_secs),
settings_version,
)),
AssistantProvider::Anthropic {
default_model,
api_url,
low_speed_timeout_in_seconds,
} => CompletionProvider::Anthropic(AnthropicCompletionProvider::new(
default_model.clone(),
api_url.clone(),
client.http_client(),
low_speed_timeout_in_seconds.map(Duration::from_secs),
settings_version,
)),
};
cx.set_global(provider);

Expand All @@ -64,13 +77,28 @@ pub fn init(client: Arc<Client>, cx: &mut AppContext) {
settings_version,
);
}
(
CompletionProvider::Anthropic(provider),
AssistantProvider::Anthropic {
default_model,
api_url,
low_speed_timeout_in_seconds,
},
) => {
provider.update(
default_model.clone(),
api_url.clone(),
low_speed_timeout_in_seconds.map(Duration::from_secs),
settings_version,
);
}
(
CompletionProvider::ZedDotDev(provider),
AssistantProvider::ZedDotDev { default_model },
) => {
provider.update(default_model.clone(), settings_version);
}
(CompletionProvider::OpenAi(_), AssistantProvider::ZedDotDev { default_model }) => {
(_, AssistantProvider::ZedDotDev { default_model }) => {
*provider = CompletionProvider::ZedDotDev(ZedDotDevCompletionProvider::new(
default_model.clone(),
client.clone(),
Expand All @@ -79,7 +107,7 @@ pub fn init(client: Arc<Client>, cx: &mut AppContext) {
));
}
(
CompletionProvider::ZedDotDev(_),
_,
AssistantProvider::OpenAi {
default_model,
api_url,
Expand All @@ -94,8 +122,22 @@ pub fn init(client: Arc<Client>, cx: &mut AppContext) {
settings_version,
));
}
#[cfg(test)]
(CompletionProvider::Fake(_), _) => unimplemented!(),
(
_,
AssistantProvider::Anthropic {
default_model,
api_url,
low_speed_timeout_in_seconds,
},
) => {
*provider = CompletionProvider::Anthropic(AnthropicCompletionProvider::new(
default_model.clone(),
api_url.clone(),
client.http_client(),
low_speed_timeout_in_seconds.map(Duration::from_secs),
settings_version,
));
}
}
})
})
Expand All @@ -104,6 +146,7 @@ pub fn init(client: Arc<Client>, cx: &mut AppContext) {

pub enum CompletionProvider {
OpenAi(OpenAiCompletionProvider),
Anthropic(AnthropicCompletionProvider),
ZedDotDev(ZedDotDevCompletionProvider),
#[cfg(test)]
Fake(FakeCompletionProvider),
Expand All @@ -119,6 +162,7 @@ impl CompletionProvider {
pub fn settings_version(&self) -> usize {
match self {
CompletionProvider::OpenAi(provider) => provider.settings_version(),
CompletionProvider::Anthropic(provider) => provider.settings_version(),
CompletionProvider::ZedDotDev(provider) => provider.settings_version(),
#[cfg(test)]
CompletionProvider::Fake(_) => unimplemented!(),
Expand All @@ -128,6 +172,7 @@ impl CompletionProvider {
pub fn is_authenticated(&self) -> bool {
match self {
CompletionProvider::OpenAi(provider) => provider.is_authenticated(),
CompletionProvider::Anthropic(provider) => provider.is_authenticated(),
CompletionProvider::ZedDotDev(provider) => provider.is_authenticated(),
#[cfg(test)]
CompletionProvider::Fake(_) => true,
Expand All @@ -137,6 +182,7 @@ impl CompletionProvider {
pub fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
match self {
CompletionProvider::OpenAi(provider) => provider.authenticate(cx),
CompletionProvider::Anthropic(provider) => provider.authenticate(cx),
CompletionProvider::ZedDotDev(provider) => provider.authenticate(cx),
#[cfg(test)]
CompletionProvider::Fake(_) => Task::ready(Ok(())),
Expand All @@ -146,6 +192,7 @@ impl CompletionProvider {
pub fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
match self {
CompletionProvider::OpenAi(provider) => provider.authentication_prompt(cx),
CompletionProvider::Anthropic(provider) => provider.authentication_prompt(cx),
CompletionProvider::ZedDotDev(provider) => provider.authentication_prompt(cx),
#[cfg(test)]
CompletionProvider::Fake(_) => unimplemented!(),
Expand All @@ -155,6 +202,7 @@ impl CompletionProvider {
pub fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
match self {
CompletionProvider::OpenAi(provider) => provider.reset_credentials(cx),
CompletionProvider::Anthropic(provider) => provider.reset_credentials(cx),
CompletionProvider::ZedDotDev(_) => Task::ready(Ok(())),
#[cfg(test)]
CompletionProvider::Fake(_) => Task::ready(Ok(())),
Expand All @@ -164,6 +212,9 @@ impl CompletionProvider {
pub fn default_model(&self) -> LanguageModel {
match self {
CompletionProvider::OpenAi(provider) => LanguageModel::OpenAi(provider.default_model()),
CompletionProvider::Anthropic(provider) => {
LanguageModel::Anthropic(provider.default_model())
}
CompletionProvider::ZedDotDev(provider) => {
LanguageModel::ZedDotDev(provider.default_model())
}
Expand All @@ -179,6 +230,7 @@ impl CompletionProvider {
) -> BoxFuture<'static, Result<usize>> {
match self {
CompletionProvider::OpenAi(provider) => provider.count_tokens(request, cx),
CompletionProvider::Anthropic(provider) => provider.count_tokens(request, cx),
CompletionProvider::ZedDotDev(provider) => provider.count_tokens(request, cx),
#[cfg(test)]
CompletionProvider::Fake(_) => unimplemented!(),
Expand All @@ -191,6 +243,7 @@ impl CompletionProvider {
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
match self {
CompletionProvider::OpenAi(provider) => provider.complete(request),
CompletionProvider::Anthropic(provider) => provider.complete(request),
CompletionProvider::ZedDotDev(provider) => provider.complete(request),
#[cfg(test)]
CompletionProvider::Fake(provider) => provider.complete(),
Expand Down
Loading

0 comments on commit 5944caa

Please sign in to comment.