Skip to content

Commit

Permalink
Expose context server settings to extensions (#20555)
Browse files Browse the repository at this point in the history
This PR exposes context server settings to extensions.

Extensions can use `ContextServerSettings::for_project` to get the
context server settings for the current project.

The `experimental.context_servers` setting has been removed and replaced
with the `context_servers` setting (which is now an object instead of an
array).

Release Notes:

- N/A

---------

Co-authored-by: Max Brunsfeld <[email protected]>
  • Loading branch information
maxdeviant and maxbrunsfeld authored Nov 12, 2024
1 parent 0a9c78a commit 3ebb64e
Show file tree
Hide file tree
Showing 17 changed files with 239 additions and 122 deletions.
2 changes: 2 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

13 changes: 2 additions & 11 deletions assets/settings/default.json
Original file line number Diff line number Diff line change
Expand Up @@ -1182,15 +1182,6 @@
// }
// ]
"ssh_connections": [],
// Configures the Context Server Protocol binaries
//
// Examples:
// {
// "id": "server-1",
// "executable": "/path",
// "args": ['arg1", "args2"]
// }
"experimental.context_servers": {
"servers": []
}
// Configures context servers for use in the Assistant.
"context_servers": {}
}
73 changes: 38 additions & 35 deletions crates/assistant/src/context_store.rs
Original file line number Diff line number Diff line change
Expand Up @@ -145,49 +145,52 @@ impl ContextStore {
project: project.clone(),
prompt_builder,
};
this.handle_project_changed(project, cx);
this.handle_project_changed(project.clone(), cx);
this.synchronize_contexts(cx);
this.register_context_server_handlers(cx);

// TODO: At the time when we construct the `ContextStore` we may not have yet initialized the extensions.
// In order to register the context servers when the extension is loaded, we're periodically looping to
// see if there are context servers to register.
//
// I tried doing this in a subscription on the `ExtensionStore`, but it never seemed to fire.
//
// We should find a more elegant way to do this.
let context_server_factory_registry =
ContextServerFactoryRegistry::default_global(cx);
cx.spawn(|context_store, mut cx| async move {
loop {
let mut servers_to_register = Vec::new();
for (_id, factory) in
context_server_factory_registry.context_server_factories()
{
if let Some(server) = factory(&cx).await.log_err() {
servers_to_register.push(server);
if project.read(cx).is_local() {
// TODO: At the time when we construct the `ContextStore` we may not have yet initialized the extensions.
// In order to register the context servers when the extension is loaded, we're periodically looping to
// see if there are context servers to register.
//
// I tried doing this in a subscription on the `ExtensionStore`, but it never seemed to fire.
//
// We should find a more elegant way to do this.
let context_server_factory_registry =
ContextServerFactoryRegistry::default_global(cx);
cx.spawn(|context_store, mut cx| async move {
loop {
let mut servers_to_register = Vec::new();
for (_id, factory) in
context_server_factory_registry.context_server_factories()
{
if let Some(server) = factory(project.clone(), &cx).await.log_err()
{
servers_to_register.push(server);
}
}
}

let Some(_) = context_store
.update(&mut cx, |this, cx| {
this.context_server_manager.update(cx, |this, cx| {
for server in servers_to_register {
this.add_server(server, cx).detach_and_log_err(cx);
}
let Some(_) = context_store
.update(&mut cx, |this, cx| {
this.context_server_manager.update(cx, |this, cx| {
for server in servers_to_register {
this.add_server(server, cx).detach_and_log_err(cx);
}
})
})
})
.log_err()
else {
break;
};
.log_err()
else {
break;
};

smol::Timer::after(Duration::from_millis(100)).await;
}
smol::Timer::after(Duration::from_millis(100)).await;
}

anyhow::Ok(())
})
.detach_and_log_err(cx);
anyhow::Ok(())
})
.detach_and_log_err(cx);
}

this
})?;
Expand Down
1 change: 1 addition & 0 deletions crates/context_servers/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ gpui.workspace = true
log.workspace = true
parking_lot.workspace = true
postage.workspace = true
project.workspace = true
schemars.workspace = true
serde.workspace = true
serde_json.workspace = true
Expand Down
2 changes: 1 addition & 1 deletion crates/context_servers/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ pub struct Client {

#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[repr(transparent)]
pub struct ContextServerId(pub String);
pub struct ContextServerId(pub Arc<str>);

fn is_null_value<T: Serialize>(value: &T) -> bool {
if let Ok(Value::Null) = serde_json::to_value(value) {
Expand Down
57 changes: 34 additions & 23 deletions crates/context_servers/src/manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ use std::path::Path;
use std::pin::Pin;
use std::sync::Arc;

use anyhow::Result;
use anyhow::{bail, Result};
use async_trait::async_trait;
use collections::{HashMap, HashSet};
use futures::{Future, FutureExt};
Expand All @@ -36,19 +36,25 @@ use crate::{

#[derive(Deserialize, Serialize, Default, Clone, PartialEq, Eq, JsonSchema, Debug)]
pub struct ContextServerSettings {
pub servers: Vec<ServerConfig>,
#[serde(default)]
pub context_servers: HashMap<Arc<str>, ServerConfig>,
}

#[derive(Deserialize, Serialize, Clone, PartialEq, Eq, JsonSchema, Debug)]
#[derive(Deserialize, Serialize, Clone, PartialEq, Eq, JsonSchema, Debug, Default)]
pub struct ServerConfig {
pub id: String,
pub executable: String,
pub command: Option<ServerCommand>,
pub settings: Option<serde_json::Value>,
}

#[derive(Deserialize, Serialize, Clone, PartialEq, Eq, JsonSchema, Debug)]
pub struct ServerCommand {
pub path: String,
pub args: Vec<String>,
pub env: Option<HashMap<String, String>>,
}

impl Settings for ContextServerSettings {
const KEY: Option<&'static str> = Some("experimental.context_servers");
const KEY: Option<&'static str> = None;

type FileContent = Self;

Expand Down Expand Up @@ -79,9 +85,9 @@ pub struct NativeContextServer {
}

impl NativeContextServer {
pub fn new(config: Arc<ServerConfig>) -> Self {
pub fn new(id: Arc<str>, config: Arc<ServerConfig>) -> Self {
Self {
id: config.id.clone().into(),
id,
config,
client: RwLock::new(None),
}
Expand All @@ -107,13 +113,16 @@ impl ContextServer for NativeContextServer {
cx: &'a AsyncAppContext,
) -> Pin<Box<dyn 'a + Future<Output = Result<()>>>> {
async move {
log::info!("starting context server {}", self.config.id,);
log::info!("starting context server {}", self.id);
let Some(command) = &self.config.command else {
bail!("no command specified for server {}", self.id);
};
let client = Client::new(
client::ContextServerId(self.config.id.clone()),
client::ContextServerId(self.id.clone()),
client::ModelContextServerBinary {
executable: Path::new(&self.config.executable).to_path_buf(),
args: self.config.args.clone(),
env: self.config.env.clone(),
executable: Path::new(&command.path).to_path_buf(),
args: command.args.clone(),
env: command.env.clone(),
},
cx.clone(),
)?;
Expand All @@ -127,7 +136,7 @@ impl ContextServer for NativeContextServer {

log::debug!(
"context server {} initialized: {:?}",
self.config.id,
self.id,
initialized_protocol.initialize,
);

Expand Down Expand Up @@ -242,7 +251,7 @@ impl ContextServerManager {
if let Some(server) = this.update(&mut cx, |this, _cx| this.servers.remove(&id))? {
server.stop()?;
let config = server.config();
let new_server = Arc::new(NativeContextServer::new(config));
let new_server = Arc::new(NativeContextServer::new(id.clone(), config));
new_server.clone().start(&cx).await?;
this.update(&mut cx, |this, cx| {
this.servers.insert(id.clone(), new_server);
Expand Down Expand Up @@ -270,15 +279,15 @@ impl ContextServerManager {
.collect::<HashMap<_, _>>();

let new_servers = settings
.servers
.context_servers
.iter()
.map(|config| (config.id.clone(), config.clone()))
.map(|(id, config)| (id.clone(), config.clone()))
.collect::<HashMap<_, _>>();

let servers_to_add = new_servers
.values()
.filter(|config| !current_servers.contains_key(config.id.as_str()))
.cloned()
.iter()
.filter(|(id, _)| !current_servers.contains_key(id.as_ref()))
.map(|(id, config)| (id.clone(), config.clone()))
.collect::<Vec<_>>();

let servers_to_remove = current_servers
Expand All @@ -288,9 +297,11 @@ impl ContextServerManager {
.collect::<Vec<_>>();

log::trace!("servers_to_add={:?}", servers_to_add);
for config in servers_to_add {
let server = Arc::new(NativeContextServer::new(Arc::new(config)));
self.add_server(server, cx).detach_and_log_err(cx);
for (id, config) in servers_to_add {
if config.command.is_some() {
let server = Arc::new(NativeContextServer::new(id, Arc::new(config)));
self.add_server(server, cx).detach_and_log_err(cx);
}
}

for id in servers_to_remove {
Expand Down
12 changes: 8 additions & 4 deletions crates/context_servers/src/registry.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,18 @@ use std::sync::Arc;

use anyhow::Result;
use collections::HashMap;
use gpui::{AppContext, AsyncAppContext, ReadGlobal};
use gpui::{Global, Task};
use gpui::{AppContext, AsyncAppContext, Global, Model, ReadGlobal, Task};
use parking_lot::RwLock;
use project::Project;

use crate::ContextServer;

pub type ContextServerFactory =
Arc<dyn Fn(&AsyncAppContext) -> Task<Result<Arc<dyn ContextServer>>> + Send + Sync + 'static>;
pub type ContextServerFactory = Arc<
dyn Fn(Model<Project>, &AsyncAppContext) -> Task<Result<Arc<dyn ContextServer>>>
+ Send
+ Sync
+ 'static,
>;

#[derive(Default)]
struct GlobalContextServerFactoryRegistry(Arc<ContextServerFactoryRegistry>);
Expand Down
15 changes: 11 additions & 4 deletions crates/extension_api/src/extension_api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ pub use wit::{
SlashCommand, SlashCommandArgumentCompletion, SlashCommandOutput, SlashCommandOutputSection,
},
CodeLabel, CodeLabelSpan, CodeLabelSpanLiteral, Command, DownloadedFileType, EnvVars,
KeyValueStore, LanguageServerInstallationStatus, Range, Worktree,
KeyValueStore, LanguageServerInstallationStatus, Project, Range, Worktree,
};

// Undocumented WIT re-exports.
Expand Down Expand Up @@ -130,7 +130,11 @@ pub trait Extension: Send + Sync {
}

/// Returns the command used to start a context server.
fn context_server_command(&mut self, _context_server_id: &ContextServerId) -> Result<Command> {
fn context_server_command(
&mut self,
_context_server_id: &ContextServerId,
_project: &Project,
) -> Result<Command> {
Err("`context_server_command` not implemented".to_string())
}

Expand Down Expand Up @@ -275,9 +279,12 @@ impl wit::Guest for Component {
extension().run_slash_command(command, args, worktree)
}

fn context_server_command(context_server_id: String) -> Result<wit::Command> {
fn context_server_command(
context_server_id: String,
project: &Project,
) -> Result<wit::Command> {
let context_server_id = ContextServerId(context_server_id);
extension().context_server_command(&context_server_id)
extension().context_server_command(&context_server_id, project)
}

fn suggest_docs_packages(provider: String) -> Result<Vec<String>, String> {
Expand Down
54 changes: 38 additions & 16 deletions crates/extension_api/src/settings.rs
Original file line number Diff line number Diff line change
@@ -1,34 +1,56 @@
//! Provides access to Zed settings.
#[path = "../wit/since_v0.1.0/settings.rs"]
#[path = "../wit/since_v0.2.0/settings.rs"]
mod types;

use crate::{wit, Result, SettingsLocation, Worktree};
use crate::{wit, Project, Result, SettingsLocation, Worktree};
use serde_json;
pub use types::*;

impl LanguageSettings {
/// Returns the [`LanguageSettings`] for the given language.
pub fn for_worktree(language: Option<&str>, worktree: &Worktree) -> Result<Self> {
let location = SettingsLocation {
worktree_id: worktree.id(),
path: worktree.root_path(),
};
let settings_json = wit::get_settings(Some(&location), "language", language)?;
let settings: Self = serde_json::from_str(&settings_json).map_err(|err| err.to_string())?;
Ok(settings)
get_settings("language", language, Some(worktree.id()))
}
}

impl LspSettings {
/// Returns the [`LspSettings`] for the given language server.
pub fn for_worktree(language_server_name: &str, worktree: &Worktree) -> Result<Self> {
let location = SettingsLocation {
worktree_id: worktree.id(),
path: worktree.root_path(),
};
let settings_json = wit::get_settings(Some(&location), "lsp", Some(language_server_name))?;
let settings: Self = serde_json::from_str(&settings_json).map_err(|err| err.to_string())?;
Ok(settings)
get_settings("lsp", Some(language_server_name), Some(worktree.id()))
}
}

impl ContextServerSettings {
/// Returns the [`ContextServerSettings`] for the given context server.
pub fn for_project(context_server_id: &str, project: &Project) -> Result<Self> {
let global_setting: Self = get_settings("context_servers", Some(context_server_id), None)?;

for worktree_id in project.worktree_ids() {
let settings = get_settings(
"context_servers",
Some(context_server_id),
Some(worktree_id),
)?;
if settings != global_setting {
return Ok(settings);
}
}

Ok(global_setting)
}
}

fn get_settings<T: serde::de::DeserializeOwned>(
settings_type: &str,
settings_name: Option<&str>,
worktree_id: Option<u64>,
) -> Result<T> {
let location = worktree_id.map(|worktree_id| SettingsLocation {
worktree_id,
path: String::new(),
});
let settings_json = wit::get_settings(location.as_ref(), settings_type, settings_name)?;
let settings: T = serde_json::from_str(&settings_json).map_err(|err| err.to_string())?;
Ok(settings)
}
Loading

0 comments on commit 3ebb64e

Please sign in to comment.