From 0c1096d9ca9e7c38f9bd03e99d5dc20d016e61da Mon Sep 17 00:00:00 2001 From: Michael Sloan Date: Fri, 10 Jan 2025 00:27:58 -0700 Subject: [PATCH] assistant2: Implement refresh of context on message editor send --- crates/assistant2/src/context.rs | 63 +++++-- crates/assistant2/src/context_store.rs | 225 ++++++++++++++++++------ crates/assistant2/src/message_editor.rs | 62 ++++--- 3 files changed, 254 insertions(+), 96 deletions(-) diff --git a/crates/assistant2/src/context.rs b/crates/assistant2/src/context.rs index aad1805b9adc8..a69c9358667ca 100644 --- a/crates/assistant2/src/context.rs +++ b/crates/assistant2/src/context.rs @@ -1,6 +1,5 @@ use std::path::Path; use std::rc::Rc; -use std::sync::Arc; use file_icons::FileIcons; use gpui::{AppContext, Model, SharedString}; @@ -11,7 +10,7 @@ use text::BufferId; use ui::IconName; use util::post_inc; -use crate::thread::Thread; +use crate::{context_store::buffer_path_log_err, thread::Thread}; #[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)] pub struct ContextId(pub(crate) usize); @@ -76,7 +75,7 @@ impl Context { #[derive(Debug)] pub struct FileContext { pub id: ContextId, - pub buffer: ContextBuffer, + pub context_buffer: ContextBuffer, } #[derive(Debug)] @@ -84,7 +83,7 @@ pub struct DirectoryContext { #[allow(unused)] pub path: Rc, #[allow(unused)] - pub buffers: Vec, + pub context_buffers: Vec, pub snapshot: ContextSnapshot, } @@ -108,7 +107,7 @@ pub struct ThreadContext { // TODO: Model holds onto the buffer even if the file is deleted and closed. Should remove // the context from the message editor in this case. -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct ContextBuffer { #[allow(unused)] pub id: BufferId, @@ -130,18 +129,9 @@ impl Context { } impl FileContext { - pub fn path(&self, cx: &AppContext) -> Option> { - let buffer = self.buffer.buffer.read(cx); - if let Some(file) = buffer.file() { - Some(file.path().clone()) - } else { - log::error!("Buffer that had a path unexpectedly no longer has a path."); - None - } - } - pub fn snapshot(&self, cx: &AppContext) -> Option { - let path = self.path(cx)?; + let buffer = self.context_buffer.buffer.read(cx); + let path = buffer_path_log_err(buffer)?; let full_path: SharedString = path.to_string_lossy().into_owned().into(); let name = match path.file_name() { Some(name) => name.to_string_lossy().into_owned().into(), @@ -161,12 +151,51 @@ impl FileContext { tooltip: Some(full_path), icon_path, kind: ContextKind::File, - text: Box::new([self.buffer.text.clone()]), + text: Box::new([self.context_buffer.text.clone()]), }) } } impl DirectoryContext { + pub fn new( + id: ContextId, + path: &Path, + context_buffers: Vec, + ) -> DirectoryContext { + let full_path: SharedString = path.to_string_lossy().into_owned().into(); + + let name = match path.file_name() { + Some(name) => name.to_string_lossy().into_owned().into(), + None => full_path.clone(), + }; + + let parent = path + .parent() + .and_then(|p| p.file_name()) + .map(|p| p.to_string_lossy().into_owned().into()); + + // TODO: include directory path in text? + let text = context_buffers + .iter() + .map(|b| b.text.clone()) + .collect::>() + .into(); + + DirectoryContext { + path: path.into(), + context_buffers, + snapshot: ContextSnapshot { + id, + name, + parent, + tooltip: Some(full_path), + icon_path: None, + kind: ContextKind::Directory, + text, + }, + } + } + pub fn snapshot(&self) -> ContextSnapshot { self.snapshot.clone() } diff --git a/crates/assistant2/src/context_store.rs b/crates/assistant2/src/context_store.rs index 72f66689c96c1..66f4a7eace1b5 100644 --- a/crates/assistant2/src/context_store.rs +++ b/crates/assistant2/src/context_store.rs @@ -3,6 +3,7 @@ use std::sync::Arc; use anyhow::{anyhow, bail, Result}; use collections::{BTreeMap, HashMap}; +use futures::{self, future, Future, FutureExt}; use gpui::{AppContext, AsyncAppContext, Model, ModelContext, SharedString, Task, WeakView}; use language::Buffer; use project::{ProjectPath, Worktree}; @@ -11,8 +12,8 @@ use text::BufferId; use workspace::Workspace; use crate::context::{ - Context, ContextBuffer, ContextId, ContextKind, ContextSnapshot, DirectoryContext, - FetchedUrlContext, FileContext, ThreadContext, + Context, ContextBuffer, ContextId, ContextSnapshot, DirectoryContext, FetchedUrlContext, + FileContext, ThreadContext, }; use crate::thread::{Thread, ThreadId}; @@ -104,7 +105,7 @@ impl ContextStore { project_path.path.clone(), buffer_model, buffer, - &cx.to_async(), + cx.to_async(), ) })?; @@ -133,7 +134,7 @@ impl ContextStore { file.path().clone(), buffer_model, buffer, - &cx.to_async(), + cx.to_async(), )) })??; @@ -150,10 +151,8 @@ impl ContextStore { pub fn insert_file(&mut self, context_buffer: ContextBuffer) { let id = self.next_context_id.post_inc(); self.files.insert(context_buffer.id, id); - self.context.push(Context::File(FileContext { - id, - buffer: context_buffer, - })); + self.context + .push(Context::File(FileContext { id, context_buffer })); } pub fn add_directory( @@ -207,7 +206,7 @@ impl ContextStore { .collect::>() })?; - let buffers = futures::future::join_all(open_buffer_tasks).await; + let buffers = future::join_all(open_buffer_tasks).await; let mut buffer_infos = Vec::new(); let mut text_tasks = Vec::new(); @@ -216,68 +215,41 @@ impl ContextStore { let buffer_model = buffer_model?; let buffer = buffer_model.read(cx); let (buffer_info, text_task) = - collect_buffer_info_and_text(path, buffer_model, buffer, &cx.to_async()); + collect_buffer_info_and_text(path, buffer_model, buffer, cx.to_async()); buffer_infos.push(buffer_info); text_tasks.push(text_task); } anyhow::Ok(()) })??; - let buffer_texts = futures::future::join_all(text_tasks).await; - let directory_buffers = buffer_infos + let buffer_texts = future::join_all(text_tasks).await; + let context_buffers = buffer_infos .into_iter() - .zip(buffer_texts.iter()) - .map(|(info, text)| make_context_buffer(info, text.clone())) + .zip(buffer_texts) + .map(|(info, text)| make_context_buffer(info, text)) .collect::>(); - if directory_buffers.is_empty() { + if context_buffers.is_empty() { bail!("No text files found in {}", &project_path.path.display()); } - // TODO: include directory path in text? - this.update(&mut cx, |this, _| { - this.insert_directory(&project_path.path, directory_buffers, buffer_texts.into()); + this.insert_directory(&project_path.path, context_buffers); })?; anyhow::Ok(()) }) } - pub fn insert_directory( - &mut self, - path: &Path, - buffers: Vec, - text: Box<[SharedString]>, - ) { + pub fn insert_directory(&mut self, path: &Path, context_buffers: Vec) { let id = self.next_context_id.post_inc(); self.directories.insert(path.to_path_buf(), id); - let full_path: SharedString = path.to_string_lossy().into_owned().into(); - - let name = match path.file_name() { - Some(name) => name.to_string_lossy().into_owned().into(), - None => full_path.clone(), - }; - - let parent = path - .parent() - .and_then(|p| p.file_name()) - .map(|p| p.to_string_lossy().into_owned().into()); - - self.context.push(Context::Directory(DirectoryContext { - path: path.into(), - buffers, - snapshot: ContextSnapshot { - id, - name, - parent, - tooltip: Some(full_path), - icon_path: None, - kind: ContextKind::Directory, - text, - }, - })); + self.context.push(Context::Directory(DirectoryContext::new( + id, + path, + context_buffers, + ))); } pub fn add_thread(&mut self, thread: Model, cx: &mut ModelContext) { @@ -347,7 +319,8 @@ impl ContextStore { if !self.files.is_empty() { let found_file_context = self.context.iter().find(|context| match &context { Context::File(file_context) => { - if let Some(file_path) = file_context.path(cx) { + let buffer = file_context.context_buffer.buffer.read(cx); + if let Some(file_path) = buffer_path_log_err(buffer) { *file_path == *path } else { false @@ -390,6 +363,17 @@ impl ContextStore { pub fn includes_url(&self, url: &str) -> Option { self.fetched_urls.get(url).copied() } + + /// Replaces the context that matches the ID of the new context, if any match. + fn replace_context(&mut self, new_context: Context) { + let id = new_context.id(); + for context in self.context.iter_mut() { + if context.id() == id { + *context = new_context; + break; + } + } + } } pub enum FileInclusion { @@ -417,7 +401,7 @@ fn collect_buffer_info_and_text( path: Arc, buffer_model: Model, buffer: &Buffer, - cx: &AsyncAppContext, + cx: AsyncAppContext, ) -> (BufferInfo, Task) { let buffer_info = BufferInfo { id: buffer.remote_id(), @@ -432,6 +416,15 @@ fn collect_buffer_info_and_text( (buffer_info, text_task) } +pub fn buffer_path_log_err(buffer: &Buffer) -> Option> { + if let Some(file) = buffer.file() { + Some(file.path().clone()) + } else { + log::error!("Buffer that had a path unexpectedly no longer has a path."); + None + } +} + fn to_fenced_codeblock(path: &Path, content: Rope) -> SharedString { let path_extension = path.extension().and_then(|ext| ext.to_str()); let path_string = path.to_string_lossy(); @@ -485,3 +478,133 @@ fn collect_files_in_path(worktree: &Worktree, path: &Path) -> Vec> { files } + +pub fn refresh_context_store_text( + context_store: Model, + cx: &AppContext, +) -> impl Future { + let mut tasks = Vec::new(); + let context_store_ref = context_store.read(cx); + for context in &context_store_ref.context { + match context { + Context::File(file_context) => { + let context_store = context_store.clone(); + if let Some(task) = refresh_file_text(context_store, file_context, cx) { + tasks.push(task); + } + } + Context::Directory(directory_context) => { + let context_store = context_store.clone(); + if let Some(task) = refresh_directory_text(context_store, directory_context, cx) { + tasks.push(task); + } + } + Context::Thread(thread_context) => { + let context_store = context_store.clone(); + tasks.push(refresh_thread_text(context_store, thread_context, cx)); + } + // Intentionally omit refreshing fetched URLs as it doesn't seem all that useful, + // and doing the caching properly could be tricky (unless it's already handled by + // the HttpClient?). + Context::FetchedUrl(_) => {} + } + } + + future::join_all(tasks).map(|_| ()) +} + +fn refresh_file_text( + context_store: Model, + file_context: &FileContext, + cx: &AppContext, +) -> Option> { + let id = file_context.id; + let task = refresh_context_buffer(&file_context.context_buffer, cx); + if let Some(task) = task { + Some(cx.spawn(|mut cx| async move { + let context_buffer = task.await; + context_store + .update(&mut cx, |context_store, _| { + let new_file_context = FileContext { id, context_buffer }; + context_store.replace_context(Context::File(new_file_context)); + }) + .ok(); + })) + } else { + None + } +} + +fn refresh_directory_text( + context_store: Model, + directory_context: &DirectoryContext, + cx: &AppContext, +) -> Option> { + let mut stale = false; + let futures = directory_context + .context_buffers + .iter() + .map(|context_buffer| { + if let Some(refresh_task) = refresh_context_buffer(context_buffer, cx) { + stale = true; + future::Either::Left(refresh_task) + } else { + future::Either::Right(future::ready((*context_buffer).clone())) + } + }) + .collect::>(); + + if !stale { + return None; + } + + let context_buffers = future::join_all(futures); + + let id = directory_context.snapshot.id; + let path = directory_context.path.clone(); + Some(cx.spawn(|mut cx| async move { + let context_buffers = context_buffers.await; + context_store + .update(&mut cx, |context_store, _| { + let new_directory_context = DirectoryContext::new(id, &path, context_buffers); + context_store.replace_context(Context::Directory(new_directory_context)); + }) + .ok(); + })) +} + +fn refresh_thread_text( + context_store: Model, + thread_context: &ThreadContext, + cx: &AppContext, +) -> Task<()> { + let id = thread_context.id; + let thread = thread_context.thread.clone(); + cx.spawn(move |mut cx| async move { + context_store + .update(&mut cx, |context_store, cx| { + let text = thread.read(cx).text().into(); + context_store.replace_context(Context::Thread(ThreadContext { id, thread, text })); + }) + .ok(); + }) +} + +fn refresh_context_buffer( + context_buffer: &ContextBuffer, + cx: &AppContext, +) -> Option> { + let buffer = context_buffer.buffer.read(cx); + let path = buffer_path_log_err(buffer)?; + if buffer.version.changed_since(&context_buffer.version) { + let (buffer_info, text_task) = collect_buffer_info_and_text( + path, + context_buffer.buffer.clone(), + buffer, + cx.to_async(), + ); + Some(text_task.map(move |text| make_context_buffer(buffer_info, text))) + } else { + None + } +} diff --git a/crates/assistant2/src/message_editor.rs b/crates/assistant2/src/message_editor.rs index e9a7b4fc8e7dd..a86a9efb3dc65 100644 --- a/crates/assistant2/src/message_editor.rs +++ b/crates/assistant2/src/message_editor.rs @@ -19,7 +19,7 @@ use workspace::Workspace; use crate::assistant_model_selector::AssistantModelSelector; use crate::context_picker::{ConfirmBehavior, ContextPicker}; -use crate::context_store::ContextStore; +use crate::context_store::{refresh_context_store_text, ContextStore}; use crate::context_strip::{ContextStrip, ContextStripEvent, SuggestContextKind}; use crate::thread::{RequestKind, Thread}; use crate::thread_store::ThreadStore; @@ -125,22 +125,20 @@ impl MessageEditor { self.send_to_model(RequestKind::Chat, cx); } - fn send_to_model( - &mut self, - request_kind: RequestKind, - cx: &mut ViewContext, - ) -> Option<()> { + fn send_to_model(&mut self, request_kind: RequestKind, cx: &mut ViewContext) { let provider = LanguageModelRegistry::read_global(cx).active_provider(); if provider .as_ref() .map_or(false, |provider| provider.must_accept_terms(cx)) { cx.notify(); - return None; + return; } let model_registry = LanguageModelRegistry::read_global(cx); - let model = model_registry.active_model()?; + let Some(model) = model_registry.active_model() else { + return; + }; let user_message = self.editor.update(cx, |editor, cx| { let text = editor.text(cx); @@ -148,29 +146,37 @@ impl MessageEditor { text }); - let thread = self.thread.clone(); - thread.update(cx, |thread, cx| { - let context = self.context_store.read(cx).snapshot(cx).collect::>(); - thread.insert_user_message(user_message, context, cx); - let mut request = thread.to_completion_request(request_kind, cx); + let refresh_task = refresh_context_store_text(self.context_store.clone(), cx); - if self.use_tools { - request.tools = thread - .tools() - .tools(cx) - .into_iter() - .map(|tool| LanguageModelRequestTool { - name: tool.name(), - description: tool.description(), - input_schema: tool.input_schema(), - }) - .collect(); - } + let thread = self.thread.clone(); + let context_store = self.context_store.clone(); + let use_tools = self.use_tools; + cx.spawn(move |_, mut cx| async move { + refresh_task.await; + thread + .update(&mut cx, |thread, cx| { + let context = context_store.read(cx).snapshot(cx).collect::>(); + thread.insert_user_message(user_message, context, cx); + let mut request = thread.to_completion_request(request_kind, cx); - thread.stream_completion(request, model, cx) - }); + if use_tools { + request.tools = thread + .tools() + .tools(cx) + .into_iter() + .map(|tool| LanguageModelRequestTool { + name: tool.name(), + description: tool.description(), + input_schema: tool.input_schema(), + }) + .collect(); + } - None + thread.stream_completion(request, model, cx) + }) + .ok(); + }) + .detach(); } fn handle_editor_event(