Skip to content

Commit

Permalink
Allow customization of the model used for tool calling (#15479)
Browse files Browse the repository at this point in the history
We also eliminated the `completion` crate and moved its logic into
`LanguageModelRegistry`.

Release Notes:

- N/A

---------

Co-authored-by: Nathan <[email protected]>
  • Loading branch information
as-cii and nathansobo authored Jul 30, 2024
1 parent 1bfea9d commit 99bc90a
Show file tree
Hide file tree
Showing 32 changed files with 478 additions and 691 deletions.
28 changes: 1 addition & 27 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 0 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ members = [
"crates/collections",
"crates/command_palette",
"crates/command_palette_hooks",
"crates/completion",
"crates/copilot",
"crates/db",
"crates/dev_server_projects",
Expand Down Expand Up @@ -190,7 +189,6 @@ collab_ui = { path = "crates/collab_ui" }
collections = { path = "crates/collections" }
command_palette = { path = "crates/command_palette" }
command_palette_hooks = { path = "crates/command_palette_hooks" }
completion = { path = "crates/completion" }
copilot = { path = "crates/copilot" }
db = { path = "crates/db" }
dev_server_projects = { path = "crates/dev_server_projects" }
Expand Down
19 changes: 18 additions & 1 deletion crates/anthropic/src/anthropic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,12 @@ pub enum Model {
#[serde(alias = "claude-3-haiku", rename = "claude-3-haiku-20240307")]
Claude3Haiku,
#[serde(rename = "custom")]
Custom { name: String, max_tokens: usize },
Custom {
name: String,
max_tokens: usize,
/// Override this model with a different Anthropic model for tool calls.
tool_override: Option<String>,
},
}

impl Model {
Expand Down Expand Up @@ -68,6 +73,18 @@ impl Model {
Self::Custom { max_tokens, .. } => *max_tokens,
}
}

/// Returns the model id to use for tool calls.
///
/// A custom model may carry a `tool_override`; when present, that id is
/// used instead of the model's own id. All other variants (and custom
/// models without an override) fall back to `self.id()`.
pub fn tool_model_id(&self) -> &str {
    match self {
        Self::Custom {
            tool_override: Some(override_id),
            ..
        } => override_id,
        _ => self.id(),
    }
}
}

pub async fn complete(
Expand Down
2 changes: 0 additions & 2 deletions crates/assistant/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@ client.workspace = true
clock.workspace = true
collections.workspace = true
command_palette_hooks.workspace = true
completion.workspace = true
editor.workspace = true
fs.workspace = true
futures.workspace = true
Expand Down Expand Up @@ -77,7 +76,6 @@ workspace.workspace = true
picker.workspace = true

[dev-dependencies]
completion = { workspace = true, features = ["test-support"] }
ctor.workspace = true
editor = { workspace = true, features = ["test-support"] }
env_logger.workspace = true
Expand Down
23 changes: 5 additions & 18 deletions crates/assistant/src/assistant.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ use assistant_settings::AssistantSettings;
use assistant_slash_command::SlashCommandRegistry;
use client::{proto, Client};
use command_palette_hooks::CommandPaletteFilter;
use completion::LanguageModelCompletionProvider;
pub use context::*;
pub use context_store::*;
use fs::Fs;
Expand Down Expand Up @@ -192,7 +191,7 @@ pub fn init(fs: Arc<dyn Fs>, client: Arc<Client>, cx: &mut AppContext) {

context_store::init(&client);
prompt_library::init(cx);
init_completion_provider(cx);
init_language_model_settings(cx);
assistant_slash_command::init(cx);
register_slash_commands(cx);
assistant_panel::init(cx);
Expand All @@ -217,8 +216,7 @@ pub fn init(fs: Arc<dyn Fs>, client: Arc<Client>, cx: &mut AppContext) {
.detach();
}

fn init_completion_provider(cx: &mut AppContext) {
completion::init(cx);
fn init_language_model_settings(cx: &mut AppContext) {
update_active_language_model_from_settings(cx);

cx.observe_global::<SettingsStore>(update_active_language_model_from_settings)
Expand All @@ -233,20 +231,9 @@ fn update_active_language_model_from_settings(cx: &mut AppContext) {
let settings = AssistantSettings::get_global(cx);
let provider_name = LanguageModelProviderId::from(settings.default_model.provider.clone());
let model_id = LanguageModelId::from(settings.default_model.model.clone());

let Some(provider) = LanguageModelRegistry::global(cx)
.read(cx)
.provider(&provider_name)
else {
return;
};

let models = provider.provided_models(cx);
if let Some(model) = models.iter().find(|model| model.id() == model_id).cloned() {
LanguageModelCompletionProvider::global(cx).update(cx, |completion_provider, cx| {
completion_provider.set_active_model(model, cx);
});
}
LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
registry.select_active_model(&provider_name, &model_id, cx);
});
}

fn register_slash_commands(cx: &mut AppContext) {
Expand Down
34 changes: 20 additions & 14 deletions crates/assistant/src/assistant_panel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ use anyhow::{anyhow, Result};
use assistant_slash_command::{SlashCommand, SlashCommandOutputSection};
use client::proto;
use collections::{BTreeSet, HashMap, HashSet};
use completion::LanguageModelCompletionProvider;
use editor::{
actions::{FoldAt, MoveToEndOfLine, Newline, ShowCompletions, UnfoldAt},
display_map::{
Expand All @@ -43,7 +42,7 @@ use language::{
language_settings::SoftWrap, Buffer, Capability, LanguageRegistry, LspAdapterDelegate, Point,
ToOffset,
};
use language_model::{LanguageModelProviderId, Role};
use language_model::{LanguageModelProviderId, LanguageModelRegistry, Role};
use multi_buffer::MultiBufferRow;
use picker::{Picker, PickerDelegate};
use project::{Project, ProjectLspAdapterDelegate};
Expand Down Expand Up @@ -392,9 +391,9 @@ impl AssistantPanel {
cx.subscribe(&context_editor_toolbar, Self::handle_toolbar_event),
cx.subscribe(&model_summary_editor, Self::handle_summary_editor_event),
cx.subscribe(&context_store, Self::handle_context_store_event),
cx.observe(
&LanguageModelCompletionProvider::global(cx),
|this, _, cx| {
cx.subscribe(
&LanguageModelRegistry::global(cx),
|this, _, _: &language_model::ActiveModelChanged, cx| {
this.completion_provider_changed(cx);
},
),
Expand Down Expand Up @@ -560,7 +559,7 @@ impl AssistantPanel {
})
}

let Some(new_provider_id) = LanguageModelCompletionProvider::read_global(cx)
let Some(new_provider_id) = LanguageModelRegistry::read_global(cx)
.active_provider()
.map(|p| p.id())
else {
Expand Down Expand Up @@ -599,7 +598,7 @@ impl AssistantPanel {
}

fn authentication_prompt(cx: &mut WindowContext) -> Option<AnyView> {
if let Some(provider) = LanguageModelCompletionProvider::read_global(cx).active_provider() {
if let Some(provider) = LanguageModelRegistry::read_global(cx).active_provider() {
if !provider.is_authenticated(cx) {
return Some(provider.authentication_prompt(cx));
}
Expand Down Expand Up @@ -904,9 +903,9 @@ impl AssistantPanel {
}

fn reset_credentials(&mut self, _: &ResetKey, cx: &mut ViewContext<Self>) {
LanguageModelCompletionProvider::read_global(cx)
.reset_credentials(cx)
.detach_and_log_err(cx);
if let Some(provider) = LanguageModelRegistry::read_global(cx).active_provider() {
provider.reset_credentials(cx).detach_and_log_err(cx);
}
}

fn toggle_model_selector(&mut self, _: &ToggleModelSelector, cx: &mut ViewContext<Self>) {
Expand Down Expand Up @@ -1041,11 +1040,18 @@ impl AssistantPanel {
}

fn is_authenticated(&mut self, cx: &mut ViewContext<Self>) -> bool {
LanguageModelCompletionProvider::read_global(cx).is_authenticated(cx)
LanguageModelRegistry::read_global(cx)
.active_provider()
.map_or(false, |provider| provider.is_authenticated(cx))
}

fn authenticate(&mut self, cx: &mut ViewContext<Self>) -> Task<Result<()>> {
LanguageModelCompletionProvider::read_global(cx).authenticate(cx)
LanguageModelRegistry::read_global(cx)
.active_provider()
.map_or(
Task::ready(Err(anyhow!("no active language model provider"))),
|provider| provider.authenticate(cx),
)
}

fn render_signed_in(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
Expand Down Expand Up @@ -2707,7 +2713,7 @@ impl ContextEditorToolbarItem {
}

fn render_remaining_tokens(&self, cx: &mut ViewContext<Self>) -> Option<impl IntoElement> {
let model = LanguageModelCompletionProvider::read_global(cx).active_model()?;
let model = LanguageModelRegistry::read_global(cx).active_model()?;
let context = &self
.active_context_editor
.as_ref()?
Expand Down Expand Up @@ -2779,7 +2785,7 @@ impl Render for ContextEditorToolbarItem {
.whitespace_nowrap()
.child(
Label::new(
LanguageModelCompletionProvider::read_global(cx)
LanguageModelRegistry::read_global(cx)
.active_model()
.map(|model| model.name().0)
.unwrap_or_else(|| "No model selected".into()),
Expand Down
24 changes: 12 additions & 12 deletions crates/assistant/src/assistant_settings.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ pub struct AssistantSettings {
pub dock: AssistantDockPosition,
pub default_width: Pixels,
pub default_height: Pixels,
pub default_model: AssistantDefaultModel,
pub default_model: LanguageModelSelection,
pub using_outdated_settings_version: bool,
}

Expand Down Expand Up @@ -198,25 +198,25 @@ impl AssistantSettingsContent {
.clone()
.and_then(|provider| match provider {
AssistantProviderContentV1::ZedDotDev { default_model } => {
default_model.map(|model| AssistantDefaultModel {
default_model.map(|model| LanguageModelSelection {
provider: "zed.dev".to_string(),
model: model.id().to_string(),
})
}
AssistantProviderContentV1::OpenAi { default_model, .. } => {
default_model.map(|model| AssistantDefaultModel {
default_model.map(|model| LanguageModelSelection {
provider: "openai".to_string(),
model: model.id().to_string(),
})
}
AssistantProviderContentV1::Anthropic { default_model, .. } => {
default_model.map(|model| AssistantDefaultModel {
default_model.map(|model| LanguageModelSelection {
provider: "anthropic".to_string(),
model: model.id().to_string(),
})
}
AssistantProviderContentV1::Ollama { default_model, .. } => {
default_model.map(|model| AssistantDefaultModel {
default_model.map(|model| LanguageModelSelection {
provider: "ollama".to_string(),
model: model.id().to_string(),
})
Expand All @@ -231,7 +231,7 @@ impl AssistantSettingsContent {
dock: settings.dock,
default_width: settings.default_width,
default_height: settings.default_height,
default_model: Some(AssistantDefaultModel {
default_model: Some(LanguageModelSelection {
provider: "openai".to_string(),
model: settings
.default_open_ai_model
Expand Down Expand Up @@ -325,7 +325,7 @@ impl AssistantSettingsContent {
_ => {}
},
VersionedAssistantSettingsContent::V2(settings) => {
settings.default_model = Some(AssistantDefaultModel { provider, model });
settings.default_model = Some(LanguageModelSelection { provider, model });
}
},
AssistantSettingsContent::Legacy(settings) => {
Expand Down Expand Up @@ -382,11 +382,11 @@ pub struct AssistantSettingsContentV2 {
/// Default: 320
default_height: Option<f32>,
/// The default model to use when creating new contexts.
default_model: Option<AssistantDefaultModel>,
default_model: Option<LanguageModelSelection>,
}

#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema, PartialEq)]
pub struct AssistantDefaultModel {
pub struct LanguageModelSelection {
#[schemars(schema_with = "providers_schema")]
pub provider: String,
pub model: String,
Expand All @@ -407,7 +407,7 @@ fn providers_schema(_: &mut schemars::gen::SchemaGenerator) -> schemars::schema:
.into()
}

impl Default for AssistantDefaultModel {
impl Default for LanguageModelSelection {
fn default() -> Self {
Self {
provider: "openai".to_string(),
Expand Down Expand Up @@ -542,7 +542,7 @@ mod tests {
assert!(!AssistantSettings::get_global(cx).using_outdated_settings_version);
assert_eq!(
AssistantSettings::get_global(cx).default_model,
AssistantDefaultModel {
LanguageModelSelection {
provider: "openai".into(),
model: "gpt-4o".into(),
}
Expand All @@ -555,7 +555,7 @@ mod tests {
|settings, _| {
*settings = AssistantSettingsContent::Versioned(
VersionedAssistantSettingsContent::V2(AssistantSettingsContentV2 {
default_model: Some(AssistantDefaultModel {
default_model: Some(LanguageModelSelection {
provider: "test-provider".into(),
model: "gpt-99".into(),
}),
Expand Down
Loading

0 comments on commit 99bc90a

Please sign in to comment.