From 20ddbe7237844fb1e8db6c6b5eed3a1a157d2d59 Mon Sep 17 00:00:00 2001 From: Timon Vonk Date: Sun, 26 Jan 2025 14:40:44 +0100 Subject: [PATCH] Add rudimentary support for different agent models --- src/agent/agent_factory.rs | 0 src/agent/mod.rs | 43 +++++++- src/agent/util.rs | 179 +++++++++++++++++++++++++++++++++ src/agent/v1.rs | 199 +++---------------------------------- src/config/config.rs | 10 ++ 5 files changed, 241 insertions(+), 190 deletions(-) create mode 100644 src/agent/agent_factory.rs create mode 100644 src/agent/util.rs diff --git a/src/agent/agent_factory.rs b/src/agent/agent_factory.rs new file mode 100644 index 00000000..e69de29b diff --git a/src/agent/mod.rs b/src/agent/mod.rs index c6396bf0..a2652b80 100644 --- a/src/agent/mod.rs +++ b/src/agent/mod.rs @@ -3,10 +3,47 @@ mod env_setup; mod running_agent; mod tool_summarizer; pub mod tools; +mod util; mod v1; +use std::sync::Arc; -pub use v1::start_agent; +use anyhow::Result; +use swiftide_core::Tool; + +/// NOTE: On architecture, when more agents are added, it would be nice to have the concept of an +/// (Agent/Chat) session that wraps all this complexity => Responders then update on the session. +/// Makes everything a lot simpler. 
The session can then also reference the running agent, +/// executor, etc + +#[tracing::instrument(skip(repository, command_responder))] +pub async fn start_agent( + uuid: Uuid, + repository: &Repository, + initial_query: &str, + command_responder: Arc, +) -> Result { + command_responder.update("starting up agent for the first time, this might take a while"); + + match repository.config().agent { + crate::config::SupportedAgents::V1 => { + v1::start(initial_query, uuid, repository, command_responder).await + } + } +} + +pub fn available_tools( + repository: &Repository, + github_session: Option<&Arc>, + agent_env: Option<&env_setup::AgentEnvironment>, +) -> Result>> { + match repository.config().agent { + crate::config::SupportedAgents::V1 => { + v1::available_tools(repository, github_session, agent_env) + } + } +} -// Available so it's easy to debug tools in the cli pub use running_agent::RunningAgent; -pub use v1::available_tools; +use uuid::Uuid; + +use crate::{commands::Responder, git::github::GithubSession, repository::Repository}; diff --git a/src/agent/util.rs b/src/agent/util.rs new file mode 100644 index 00000000..e968b192 --- /dev/null +++ b/src/agent/util.rs @@ -0,0 +1,179 @@ +use anyhow::Context as _; +use anyhow::Result; +use swiftide_core::SimplePrompt; +use uuid::Uuid; + +use crate::commands::Responder; + +pub async fn rename_chat( + query: &str, + fast_query_provider: &dyn SimplePrompt, + command_responder: &dyn Responder, +) -> Result<()> { + let chat_name = fast_query_provider + .prompt( + format!("Give a good, short, max 60 chars title for the following query. Only respond with the title.:\n{query}") + .into(), + ) + .await + .context("Could not get chat name")? 
+ .trim_matches('"') + .chars() + .take(60) + .collect::(); + + command_responder.rename_chat(&chat_name); + + Ok(()) +} + +pub async fn create_branch_name( + query: &str, + uuid: &Uuid, + fast_query_provider: &dyn SimplePrompt, + command_responder: &dyn Responder, +) -> Result { + let name = fast_query_provider + .prompt( + format!("Give a good, short, max 30 chars git-branch-name for the following query. Only respond with the git-branch-name.:\n{query}") + .into(), + ) + .await + .context("Could not get chat name")? + .trim_matches('"') + .chars() + .take(30) + .collect::(); + + // only keep ascii characters + let name = name.chars().filter(char::is_ascii).collect::(); + let name = name.to_lowercase(); + + // replace all non-alphanumeric characters with dashes + let name = name + .chars() + .map(|c| if c.is_alphanumeric() { c } else { '-' }) + .collect::(); + + // get the first 8 characters of the uuid + let uuid_start = uuid.to_string().chars().take(8).collect::(); + let branch_name = format!("kwaak/{name}-{uuid_start}"); + + command_responder.rename_branch(&branch_name); + + Ok(branch_name) +} + +#[cfg(test)] +mod tests { + use swiftide_core::MockSimplePrompt; + + use crate::commands::MockResponder; + use mockall::{predicate, PredicateBooleanExt}; + + use super::*; + + #[tokio::test] + async fn test_rename_chat() { + let query = "This is a query"; + let mut llm_mock = MockSimplePrompt::new(); + llm_mock + .expect_prompt() + .returning(|_| Ok("Excellent title".to_string())); + + let mut mock_responder = MockResponder::default(); + + mock_responder + .expect_rename_chat() + .with(predicate::eq("Excellent title")) + .once() + .returning(|_| ()); + + rename_chat(&query, &llm_mock as &dyn SimplePrompt, &mock_responder) + .await + .unwrap(); + } + + #[tokio::test] + async fn test_rename_chat_limits_60() { + let query = "This is a query"; + let mut llm_mock = MockSimplePrompt::new(); + llm_mock + .expect_prompt() + .returning(|_| Ok("Excellent 
title".repeat(100).to_string())); + + let mut mock_responder = MockResponder::default(); + + mock_responder + .expect_rename_chat() + .with( + predicate::str::starts_with("Excellent title") + .and(predicate::function(|s: &str| s.len() == 60)), + ) + .once() + .returning(|_| ()); + + rename_chat(&query, &llm_mock as &dyn SimplePrompt, &mock_responder) + .await + .unwrap(); + } + + #[tokio::test] + async fn test_rename_branch() { + let query = "This is a query"; + let mut llm_mock = MockSimplePrompt::new(); + llm_mock + .expect_prompt() + .returning(|_| Ok("excellent-name".to_string())); + + let mut mock_responder = MockResponder::default(); + let fixed_uuid = Uuid::parse_str("936DA01F9ADD4d9d80C702AF85C822A8").unwrap(); + + mock_responder + .expect_rename_branch() + .with(predicate::str::starts_with("kwaak/excellent-name")) + .once() + .returning(|_| ()); + + create_branch_name( + &query, + &fixed_uuid, + &llm_mock as &dyn SimplePrompt, + &mock_responder, + ) + .await + .unwrap(); + } + + // NOTE the prompt is intended to be limited to 30 characters, but the branch name in total + // has 15 more characters (total 45): "kwaak/" + "-" + 8 characters from the uuid + #[tokio::test] + async fn test_rename_branch_limits_45() { + let query = "This is a query"; + let mut llm_mock = MockSimplePrompt::new(); + llm_mock + .expect_prompt() + .returning(|_| Ok("excellent-name".repeat(100).to_string())); + + let mut mock_responder = MockResponder::default(); + let fixed_uuid = Uuid::parse_str("936DA01F9ADD4d9d80C702AF85C822A8").unwrap(); + + mock_responder + .expect_rename_branch() + .with( + predicate::str::starts_with("kwaak/excellent-name") + .and(predicate::function(|s: &str| s.len() == 45)), + ) + .once() + .returning(|_| ()); + + create_branch_name( + &query, + &fixed_uuid, + &llm_mock as &dyn SimplePrompt, + &mock_responder, + ) + .await + .unwrap(); + } +} diff --git a/src/agent/v1.rs b/src/agent/v1.rs index 28a93503..1ae9445d 100644 --- a/src/agent/v1.rs +++ 
b/src/agent/v1.rs @@ -19,8 +19,8 @@ use super::{ tools, RunningAgent, }; use crate::{ - commands::Responder, config::SupportedToolExecutors, git::github::GithubSession, indexing, - repository::Repository, util::accept_non_zero_exit, + agent::util, commands::Responder, config::SupportedToolExecutors, git::github::GithubSession, + indexing, repository::Repository, util::accept_non_zero_exit, }; use swiftide_docker_executor::DockerExecutor; @@ -83,7 +83,6 @@ async fn start_tool_executor(uuid: Uuid, repository: &Repository) -> Result Result, ) -> Result { - command_responder.update("starting up agent for the first time, this might take a while"); - let query_provider: Box = repository.config().query_provider().try_into()?; let fast_query_provider: Box = @@ -122,18 +118,20 @@ pub async fn start_agent( None => None, }; - let system_prompt = build_system_prompt(&repository)?; + // TODO: Feels a bit off to have EnvSetup return an Env, just to pass it to tool creation to + // get the ref/branch name + // + // Probably nicer to have a `ChatSession` or `AgentSession` that encapsulates all the + // complexity let ((), branch_name, executor, initial_context) = tokio::try_join!( - rename_chat(&query, &fast_query_provider, &command_responder), - create_branch_name(&query, &uuid, &fast_query_provider, &command_responder), + util::rename_chat(&query, &fast_query_provider, &command_responder), + util::create_branch_name(&query, &uuid, &fast_query_provider, &command_responder), start_tool_executor(uuid, &repository), generate_initial_context(&repository, query) )?; - let env_setup = EnvSetup::new(&repository, github_session.as_deref(), &*executor); - // TODO: Feels a bit off to have EnvSetup return an Env, just to pass it to tool creation to - // get the ref/branch name let agent_env = env_setup.exec_setup_commands(branch_name).await?; + let system_prompt = build_system_prompt(&repository)?; let tools = available_tools(&repository, github_session.as_ref(), Some(&agent_env))?; @@ 
-322,176 +320,3 @@ fn build_system_prompt(repository: &Repository) -> Result { Ok(prompt) } - -async fn rename_chat( - query: &str, - fast_query_provider: &dyn SimplePrompt, - command_responder: &dyn Responder, -) -> Result<()> { - let chat_name = fast_query_provider - .prompt( - format!("Give a good, short, max 60 chars title for the following query. Only respond with the title.:\n{query}") - .into(), - ) - .await - .context("Could not get chat name")? - .trim_matches('"') - .chars() - .take(60) - .collect::(); - - command_responder.rename_chat(&chat_name); - - Ok(()) -} - -async fn create_branch_name( - query: &str, - uuid: &Uuid, - fast_query_provider: &dyn SimplePrompt, - command_responder: &dyn Responder, -) -> Result { - let name = fast_query_provider - .prompt( - format!("Give a good, short, max 30 chars git-branch-name for the following query. Only respond with the git-branch-name.:\n{query}") - .into(), - ) - .await - .context("Could not get chat name")? - .trim_matches('"') - .chars() - .take(30) - .collect::(); - - // only keep ascii characters - let name = name.chars().filter(char::is_ascii).collect::(); - let name = name.to_lowercase(); - - // replace all non-alphanumeric characters with dashes - let name = name - .chars() - .map(|c| if c.is_alphanumeric() { c } else { '-' }) - .collect::(); - - // get the first 8 characters of the uuid - let uuid_start = uuid.to_string().chars().take(8).collect::(); - let branch_name = format!("kwaak/{name}-{uuid_start}"); - - command_responder.rename_branch(&branch_name); - - Ok(branch_name) -} - -#[cfg(test)] -mod tests { - use swiftide_core::MockSimplePrompt; - - use crate::commands::MockResponder; - use mockall::{predicate, PredicateBooleanExt}; - - use super::*; - - #[tokio::test] - async fn test_rename_chat() { - let query = "This is a query"; - let mut llm_mock = MockSimplePrompt::new(); - llm_mock - .expect_prompt() - .returning(|_| Ok("Excellent title".to_string())); - - let mut mock_responder = 
MockResponder::default(); - - mock_responder - .expect_rename_chat() - .with(predicate::eq("Excellent title")) - .once() - .returning(|_| ()); - - rename_chat(&query, &llm_mock as &dyn SimplePrompt, &mock_responder) - .await - .unwrap(); - } - - #[tokio::test] - async fn test_rename_chat_limits_60() { - let query = "This is a query"; - let mut llm_mock = MockSimplePrompt::new(); - llm_mock - .expect_prompt() - .returning(|_| Ok("Excellent title".repeat(100).to_string())); - - let mut mock_responder = MockResponder::default(); - - mock_responder - .expect_rename_chat() - .with( - predicate::str::starts_with("Excellent title") - .and(predicate::function(|s: &str| s.len() == 60)), - ) - .once() - .returning(|_| ()); - - rename_chat(&query, &llm_mock as &dyn SimplePrompt, &mock_responder) - .await - .unwrap(); - } - - #[tokio::test] - async fn test_rename_branch() { - let query = "This is a query"; - let mut llm_mock = MockSimplePrompt::new(); - llm_mock - .expect_prompt() - .returning(|_| Ok("excellent-name".to_string())); - - let mut mock_responder = MockResponder::default(); - let fixed_uuid = Uuid::parse_str("936DA01F9ADD4d9d80C702AF85C822A8").unwrap(); - - mock_responder - .expect_rename_branch() - .with(predicate::str::starts_with("kwaak/excellent-name")) - .once() - .returning(|_| ()); - - create_branch_name( - &query, - &fixed_uuid, - &llm_mock as &dyn SimplePrompt, - &mock_responder, - ) - .await - .unwrap(); - } - - // NOTE the prompt is intended to be limited to 30 characters, but the branch name in total - // has 15 more characters (total 45): "kwaak/" + "-" + 8 characters from the uuid - #[tokio::test] - async fn test_rename_branch_limits_45() { - let query = "This is a query"; - let mut llm_mock = MockSimplePrompt::new(); - llm_mock - .expect_prompt() - .returning(|_| Ok("excellent-name".repeat(100).to_string())); - - let mut mock_responder = MockResponder::default(); - let fixed_uuid = Uuid::parse_str("936DA01F9ADD4d9d80C702AF85C822A8").unwrap(); - - 
mock_responder - .expect_rename_branch() - .with( - predicate::str::starts_with("kwaak/excellent-name") - .and(predicate::function(|s: &str| s.len() == 45)), - ) - .once() - .returning(|_| ()); - - create_branch_name( - &query, - &fixed_uuid, - &llm_mock as &dyn SimplePrompt, - &mock_responder, - ) - .await - .unwrap(); - } -} diff --git a/src/config/config.rs b/src/config/config.rs index 1f286d30..01d70fa3 100644 --- a/src/config/config.rs +++ b/src/config/config.rs @@ -26,6 +26,10 @@ pub struct Config { #[serde(default = "default_log_dir")] pub log_dir: PathBuf, + /// The agent model to use by default in chats + #[serde(default)] + pub agent: SupportedAgents, + #[serde(default)] /// Concurrency for indexing /// By default for IO-bound LLMs, we assume 4x the number of CPUs @@ -82,6 +86,12 @@ fn default_otel_enabled() -> bool { false } +#[derive(PartialEq, Debug, Clone, Serialize, Deserialize, Default)] +pub enum SupportedAgents { + #[default] + V1, +} + #[derive(PartialEq, Debug, Clone, Serialize, Deserialize, Default)] #[serde(rename_all = "kebab-case")] pub enum SupportedToolExecutors {