From 019d98898e1f0ee6d81f7e587e802d8ebad56a4a Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Tue, 14 May 2024 13:55:47 +0200 Subject: [PATCH] Add support for gpt-4o when using zed.dev as the model provider (#11794) Release Notes: - N/A --- crates/assistant/src/assistant_panel.rs | 5 +++-- crates/assistant/src/assistant_settings.rs | 9 +++++++-- crates/assistant/src/completion_provider/open_ai.rs | 11 ++++++++++- crates/assistant/src/completion_provider/zed.rs | 1 + 4 files changed, 21 insertions(+), 5 deletions(-) diff --git a/crates/assistant/src/assistant_panel.rs b/crates/assistant/src/assistant_panel.rs index d090c96357c02d..c47cd022d86c39 100644 --- a/crates/assistant/src/assistant_panel.rs +++ b/crates/assistant/src/assistant_panel.rs @@ -803,12 +803,13 @@ impl AssistantPanel { LanguageModel::ZedDotDev(model) => LanguageModel::ZedDotDev(match &model { ZedDotDevModel::Gpt3Point5Turbo => ZedDotDevModel::Gpt4, ZedDotDevModel::Gpt4 => ZedDotDevModel::Gpt4Turbo, - ZedDotDevModel::Gpt4Turbo => ZedDotDevModel::Claude3Opus, + ZedDotDevModel::Gpt4Turbo => ZedDotDevModel::Gpt4Omni, + ZedDotDevModel::Gpt4Omni => ZedDotDevModel::Claude3Opus, ZedDotDevModel::Claude3Opus => ZedDotDevModel::Claude3Sonnet, ZedDotDevModel::Claude3Sonnet => ZedDotDevModel::Claude3Haiku, ZedDotDevModel::Claude3Haiku => { match CompletionProvider::global(cx).default_model() { - LanguageModel::ZedDotDev(custom) => custom, + LanguageModel::ZedDotDev(custom @ ZedDotDevModel::Custom(_)) => custom, _ => ZedDotDevModel::Gpt3Point5Turbo, } } diff --git a/crates/assistant/src/assistant_settings.rs b/crates/assistant/src/assistant_settings.rs index f2dc55e17dc90e..4a5dd399524312 100644 --- a/crates/assistant/src/assistant_settings.rs +++ b/crates/assistant/src/assistant_settings.rs @@ -16,8 +16,9 @@ use settings::{Settings, SettingsSources}; pub enum ZedDotDevModel { Gpt3Point5Turbo, Gpt4, - #[default] Gpt4Turbo, + #[default] + Gpt4Omni, Claude3Opus, Claude3Sonnet, Claude3Haiku, @@ -55,6 +56,7 @@ impl<'de> Deserialize<'de> for ZedDotDevModel { "gpt-3.5-turbo" => Ok(ZedDotDevModel::Gpt3Point5Turbo), "gpt-4" => Ok(ZedDotDevModel::Gpt4), "gpt-4-turbo-preview" => Ok(ZedDotDevModel::Gpt4Turbo), + "gpt-4o" => Ok(ZedDotDevModel::Gpt4Omni), _ => Ok(ZedDotDevModel::Custom(value.to_owned())), } } @@ -74,6 +76,7 @@ impl JsonSchema for ZedDotDevModel { "gpt-3.5-turbo".to_owned(), "gpt-4".to_owned(), "gpt-4-turbo-preview".to_owned(), + "gpt-4o".to_owned(), ]; Schema::Object(SchemaObject { instance_type: Some(InstanceType::String.into()), @@ -100,6 +103,7 @@ impl ZedDotDevModel { Self::Gpt3Point5Turbo => "gpt-3.5-turbo", Self::Gpt4 => "gpt-4", Self::Gpt4Turbo => "gpt-4-turbo-preview", + Self::Gpt4Omni => "gpt-4o", Self::Claude3Opus => "claude-3-opus", Self::Claude3Sonnet => "claude-3-sonnet", Self::Claude3Haiku => "claude-3-haiku", @@ -112,6 +116,7 @@ impl ZedDotDevModel { Self::Gpt3Point5Turbo => "GPT 3.5 Turbo", Self::Gpt4 => "GPT 4", Self::Gpt4Turbo => "GPT 4 Turbo", + Self::Gpt4Omni => "GPT 4 Omni", Self::Claude3Opus => "Claude 3 Opus", Self::Claude3Sonnet => "Claude 3 Sonnet", Self::Claude3Haiku => "Claude 3 Haiku", @@ -123,7 +128,7 @@ impl ZedDotDevModel { match self { Self::Gpt3Point5Turbo => 2048, Self::Gpt4 => 4096, - Self::Gpt4Turbo => 128000, + Self::Gpt4Turbo | Self::Gpt4Omni => 128000, Self::Claude3Opus | Self::Claude3Sonnet | Self::Claude3Haiku => 200000, Self::Custom(_) => 4096, // TODO: Make this configurable } diff --git a/crates/assistant/src/completion_provider/open_ai.rs b/crates/assistant/src/completion_provider/open_ai.rs index 6f6d2d3fe405f6..347c8318e505bf 100644 --- a/crates/assistant/src/completion_provider/open_ai.rs +++ b/crates/assistant/src/completion_provider/open_ai.rs @@ -1,3 +1,4 @@ +use crate::assistant_settings::ZedDotDevModel; use crate::{ assistant_settings::OpenAiModel, CompletionProvider, LanguageModel, LanguageModelRequest, Role, }; @@ -202,7 +203,15 @@ pub fn count_open_ai_tokens( }) .collect::>(); - tiktoken_rs::num_tokens_from_messages(request.model.id(), &messages) + match request.model { + LanguageModel::OpenAi(OpenAiModel::FourOmni) + | LanguageModel::ZedDotDev(ZedDotDevModel::Gpt4Omni) => { + // Tiktoken doesn't yet support gpt-4o, so we manually use the + // same tokenizer as GPT-4. + tiktoken_rs::num_tokens_from_messages("gpt-4", &messages) + } + _ => tiktoken_rs::num_tokens_from_messages(request.model.id(), &messages), + } }) .boxed() } diff --git a/crates/assistant/src/completion_provider/zed.rs b/crates/assistant/src/completion_provider/zed.rs index ed84f1f7c69903..fc34c7fa6a51f9 100644 --- a/crates/assistant/src/completion_provider/zed.rs +++ b/crates/assistant/src/completion_provider/zed.rs @@ -81,6 +81,7 @@ impl ZedDotDevCompletionProvider { LanguageModel::OpenAi(_) => future::ready(Err(anyhow!("invalid model"))).boxed(), LanguageModel::ZedDotDev(ZedDotDevModel::Gpt4) | LanguageModel::ZedDotDev(ZedDotDevModel::Gpt4Turbo) + | LanguageModel::ZedDotDev(ZedDotDevModel::Gpt4Omni) | LanguageModel::ZedDotDev(ZedDotDevModel::Gpt3Point5Turbo) => { count_open_ai_tokens(request, cx.background_executor()) }