diff --git a/src/onboarding/commands.rs b/src/onboarding/commands.rs
index edbe7ebf..88c2f1ce 100644
--- a/src/onboarding/commands.rs
+++ b/src/onboarding/commands.rs
@@ -1,17 +1,16 @@
+use anyhow::Result;
 use serde_json::json;
 
 use crate::onboarding::util::prompt_text;
 
-pub fn command_questions(context: &mut tera::Context) {
+pub fn command_questions(context: &mut tera::Context) -> Result<()> {
     println!("\nKwaak agents can run tests and use code coverage when coding. Kwaak uses tests as an extra feedback moment for agents");
 
-    let test_command = prompt_text("Test command (optional, <esc> to skip)", None)
-        .prompt_skippable()
-        .unwrap();
+    let test_command =
+        prompt_text("Test command (optional, <esc> to skip)", None).prompt_skippable()?;
 
-    let coverage_command = prompt_text("Coverage command (optional, <esc> to skip)", None)
-        .prompt_skippable()
-        .unwrap();
+    let coverage_command =
+        prompt_text("Coverage command (optional, <esc> to skip)", None).prompt_skippable()?;
 
     context.insert(
         "commands",
@@ -20,4 +19,6 @@ pub fn command_questions(context: &mut tera::Context) {
             "coverage": coverage_command,
         }),
     );
+
+    Ok(())
 }
diff --git a/src/onboarding/git.rs b/src/onboarding/git.rs
index 14966cd7..72c5dcab 100644
--- a/src/onboarding/git.rs
+++ b/src/onboarding/git.rs
@@ -1,3 +1,4 @@
+use anyhow::Result;
 use serde_json::json;
 
 use crate::{
@@ -5,49 +6,36 @@ use crate::{
     onboarding::util::{prompt_api_key, prompt_text},
 };
 
-pub fn git_questions(context: &mut tera::Context) {
+pub fn git_questions(context: &mut tera::Context) -> Result<()> {
     let (default_owner, default_repository) = default_owner_and_repo().unzip();
     let default_branch = default_main_branch();
 
-    let branch_input = prompt_text("Default git branch", Some(&default_branch))
-        .prompt()
-        .unwrap();
-
-    println!("\nWith a github token, Kwaak can create pull requests, search github code, and automatically push to the remote.");
-    let github_api_key = prompt_api_key(
-        "GitHub api key (optional, <esc> to skip)",
-        Some("env:GITHUB_TOKEN"),
-    )
-    .prompt_skippable()
-    .unwrap();
-
-    let auto_push_remote =
+    let branch_input = prompt_text("Default git branch", Some(&default_branch)).prompt()?;
+
+    println!("\nWith a github token, Kwaak can create pull requests, search github code, and automatically push to the remote. Kwaak will never push to the main branch.");
+
+    let github_api_key = prompt_api_key("Github token (optional, <esc> to skip)", None)
+        .with_placeholder("env:GITHUB_TOKEN")
+        .prompt_skippable()?;
+
+    let auto_push_remote = if github_api_key.is_some() {
         inquire::Confirm::new("Push to git remote after changes? (requires github token)")
-            .with_default(github_api_key.is_some())
-            .prompt()
-            .unwrap();
-
-    let owner_input = prompt_text(
-        "Git owner (optional, <esc> to skip)",
-        default_owner.as_deref(),
-    )
-    .prompt_skippable()
-    .unwrap();
-    let repository_input = prompt_text(
-        "Git repository (optional, <esc> to skip)",
-        default_repository.as_deref(),
-    )
-    .prompt_skippable()
-    .unwrap();
+            .with_default(false)
+            .prompt()?
+    } else {
+        false
+    };
 
     context.insert("github_api_key", &github_api_key);
     context.insert(
         "git",
         &json!({
-            "owner": owner_input,
-            "repository": repository_input,
+            "owner": default_owner,
+            "repository": default_repository,
             "main_branch": branch_input,
             "auto_push_remote": auto_push_remote,
         }),
     );
+
+    Ok(())
 }
diff --git a/src/onboarding/llm.rs b/src/onboarding/llm.rs
index eae630f8..8c3b812e 100644
--- a/src/onboarding/llm.rs
+++ b/src/onboarding/llm.rs
@@ -1,3 +1,4 @@
+use anyhow::Result;
 use std::collections::HashMap;
 
 use serde_json::json;
@@ -10,7 +11,8 @@ use crate::{
 
 use super::util::{prompt_api_key, prompt_select};
 
-pub async fn llm_questions(context: &mut tera::Context) {
+pub async fn llm_questions(context: &mut tera::Context) -> Result<()> {
+    println!("\nKwaak supports multiple LLM providers and uses multiple models for various tasks. What providers would you like to use?");
     let valid_llms = LLMConfiguration::VARIANTS
         .iter()
         .map(AsRef::as_ref) // Kinda weird that we need to do this
@@ -21,14 +23,13 @@ pub async fn llm_questions(context: &mut tera::Context) {
         "What LLM would you like to use?",
         valid_llms,
         Some("OpenAI"),
-    )
-    .parse()
-    .unwrap();
+    )?
+    .parse()?;
 
     match valid_llm {
-        LLMConfiguration::OpenAI { .. } => openai_questions(context),
-        LLMConfiguration::Ollama { .. } => ollama_questions(context),
-        LLMConfiguration::OpenRouter { .. } => open_router_questions(context).await,
+        LLMConfiguration::OpenAI { .. } => openai_questions(context)?,
+        LLMConfiguration::Ollama { .. } => ollama_questions(context)?,
+        LLMConfiguration::OpenRouter { .. } => open_router_questions(context).await?,
         LLMConfiguration::FastEmbed { .. } => {
            println!("{valid_llm} is not selectable yet, skipping configuration");
         }
@@ -37,31 +38,32 @@ pub async fn llm_questions(context: &mut tera::Context) {
             println!("{valid_llm} is not meant for production use, skipping configuration");
         }
     }
+
+    Ok(())
 }
 
-fn openai_questions(context: &mut tera::Context) {
+fn openai_questions(context: &mut tera::Context) -> Result<()> {
     let api_key = prompt_api_key(
         "Where can we find your OpenAI api key? (https://platform.openai.com/api-keys)",
         Some("env:OPENAI_API_KEY"),
     )
-    .prompt()
-    .unwrap();
+    .prompt()?;
 
     let indexing_model = prompt_select(
         "Model used for fast operations (like indexing)",
         OpenAIPromptModel::VARIANTS.to_vec(),
         Some("gpt-4o-mini"),
-    );
+    )?;
     let query_model = prompt_select(
         "Model used for querying and code generation",
         OpenAIPromptModel::VARIANTS.to_vec(),
         Some("gpt-4o"),
-    );
+    )?;
     let embedding_model = prompt_select(
         "Model used for embeddings",
         OpenAIEmbeddingModel::VARIANTS.to_vec(),
         Some("text-embedding-3-large"),
-    );
+    )?;
 
     context.insert("openai_api_key", &api_key);
     context.insert(
@@ -82,6 +84,8 @@ fn openai_questions(context: &mut tera::Context) {
             "base_url": None::<String>,
         }),
     );
+
+    Ok(())
 }
 
 async fn get_open_router_models() -> Option<Vec<HashMap<String, serde_json::Value>>> {
@@ -102,15 +106,14 @@ async fn get_open_router_models() -> Option<Vec<HashMap<String, serde_json::Value>>> {
     }
 }
 
-async fn open_router_questions(context: &mut tera::Context) {
+async fn open_router_questions(context: &mut tera::Context) -> Result<()> {
     println!("\nOpenRouter allows you to use a variety of managed models via a single api. You can find models at https://openrouter.ai/models.");
 
     let api_key = prompt_api_key(
         "Where can we find your OpenRouter api key? (https://openrouter.ai/settings/keys)",
         Some("env:OPEN_ROUTER_API_KEY"),
     )
-    .prompt()
-    .unwrap();
+    .prompt()?;
 
     let openrouter_models = get_open_router_models().await;
 
@@ -143,16 +146,14 @@ async fn open_router_questions(context: &mut tera::Context) {
     )
     .with_autocomplete(autocompletion.clone())
     .with_validator(validator.clone())
-    .prompt()
-    .unwrap();
+    .prompt()?;
 
     let query_model = prompt_text(
         "Model used for querying and code generation",
         Some("anthropic/claude-3.5-sonnet"),
     )
     .with_autocomplete(autocompletion.clone())
     .with_validator(validator.clone())
-    .prompt()
-    .unwrap();
+    .prompt()?;
 
     context.insert("open_router_api_key", &api_key);
@@ -167,28 +168,26 @@
         }),
     );
 
     println!("\nOpenRouter does not support embeddings yet. Currently we suggest to use FastEmbed. If you want to use a different provider you can change it in your config later.");
 
-    fastembed_questions(context);
+    fastembed_questions(context)
 }
 
-fn ollama_questions(context: &mut tera::Context) {
+fn ollama_questions(context: &mut tera::Context) -> Result<()> {
     println!("Note that you need to have a running Ollama instance.");
 
     let indexing_model = prompt_text(
         "Model used for fast operations (like indexing). This model does not need to support tool calls.",
         None
-    ).prompt().unwrap();
+    ).prompt()?;
 
     let query_model = prompt_text(
         "Model used for querying and code generation. This model needs to support tool calls.",
         None,
     )
-    .prompt()
-    .unwrap();
+    .prompt()?;
 
-    let embedding_model = prompt_text("Model used for embeddings, bge-m3 is a solid choice", None)
-        .prompt()
-        .unwrap();
+    let embedding_model =
+        prompt_text("Model used for embeddings, bge-m3 is a solid choice", None).prompt()?;
 
     let vector_size = inquire::Text::new("Vector size for the embedding model")
         .with_validator(|input: &str| match input.parse::<usize>() {
@@ -197,8 +196,7 @@
                 "Invalid number".into(),
             )),
         })
-        .prompt()
-        .unwrap();
+        .prompt()?;
 
     let base_url = inquire::Text::new("Custom base url? (optional, <esc> to skip)")
         .with_validator(|input: &str| match url::Url::parse(input) {
@@ -207,8 +205,7 @@
                 "Invalid URL".into(),
             )),
         })
-        .prompt_skippable()
-        .unwrap();
+        .prompt_skippable()?;
 
     context.insert(
         "llm",
@@ -228,18 +225,19 @@ fn ollama_questions(context: &mut tera::Context) {
             "vector_size": vector_size,
         }),
     );
+
+    Ok(())
 }
 
-pub fn fastembed_questions(context: &mut tera::Context) {
+pub fn fastembed_questions(context: &mut tera::Context) -> Result<()> {
     println!("\nFastEmbed provides embeddings that are generated quickly locally. Unless you have a specific need for a different model, the default is a good choice.");
 
     let embedding_model: FastembedModel = prompt_select(
         "Embedding model",
         FastembedModel::list_supported_models(),
         Some(FastembedModel::default().to_string()),
-    )
-    .parse()
-    .unwrap();
+    )?
+    .parse()?;
 
     context.insert(
         "embed_llm",
@@ -249,6 +247,8 @@
             "base_url": None::<String>,
         }),
     );
+
+    Ok(())
 }
 
 #[derive(Clone)]
diff --git a/src/onboarding/mod.rs b/src/onboarding/mod.rs
index e1aa28ec..3138d183 100644
--- a/src/onboarding/mod.rs
+++ b/src/onboarding/mod.rs
@@ -44,10 +44,10 @@ pub async fn run(file: Option<PathBuf>, dry_run: bool) -> Result<()> {
     println!("We have a few questions to ask you to get started, you can always change these later in the `{}` file.", file.display());
 
     let mut context = tera::Context::new();
-    project_questions(&mut context);
-    git_questions(&mut context);
-    llm_questions(&mut context).await;
-    command_questions(&mut context);
+    project_questions(&mut context)?;
+    git_questions(&mut context)?;
+    llm_questions(&mut context).await?;
+    command_questions(&mut context)?;
 
     let config =
         Templates::render("kwaak.toml", &context).context("Failed to render default config")?;
diff --git a/src/onboarding/project.rs b/src/onboarding/project.rs
index 6059911a..039cb68d 100644
--- a/src/onboarding/project.rs
+++ b/src/onboarding/project.rs
@@ -1,3 +1,4 @@
+use anyhow::Result;
 use strum::IntoEnumIterator;
 use swiftide::integrations::treesitter::SupportedLanguages;
 
@@ -5,11 +6,9 @@ use crate::config::defaults::default_project_name;
 
 use super::util::{prompt_select, prompt_text};
 
-pub fn project_questions(context: &mut tera::Context) {
+pub fn project_questions(context: &mut tera::Context) -> Result<()> {
     let project_name = default_project_name();
-    let project_name_input = prompt_text("Project name", Some(&project_name))
-        .prompt()
-        .unwrap();
+    let project_name_input = prompt_text("Project name", Some(&project_name)).prompt()?;
     context.insert("project_name", &project_name_input);
 
     // Get user inputs with defaults
@@ -18,9 +17,11 @@ pub fn project_questions(context: &mut tera::Context) {
         .map(|l| l.to_string())
         .collect::<Vec<_>>();
 
-    let language_input = prompt_select("Programming language", options.clone(), detected);
+    let language_input = prompt_select("Programming language", options.clone(), detected)?;
 
     context.insert("language", &language_input);
+
+    Ok(())
 }
 
 fn naive_lang_detect() -> Option<String> {
diff --git a/src/onboarding/util.rs b/src/onboarding/util.rs
index 8d344d84..2ad97476 100644
--- a/src/onboarding/util.rs
+++ b/src/onboarding/util.rs
@@ -1,3 +1,5 @@
+use anyhow::Result;
+
 pub fn prompt_text<'a>(prompt: &'a str, default: Option<&'a str>) -> inquire::Text<'a> {
     let mut prompt = inquire::Text::new(prompt);
 
@@ -27,7 +29,7 @@ pub fn prompt_api_key<'a>(prompt: &'a str, default: Option<&'a str>) -> inquire:
 }
 
 #[allow(clippy::needless_pass_by_value)]
-pub fn prompt_select<T>(prompt: &str, options: Vec<T>, default: Option<T>) -> String
+pub fn prompt_select<T>(prompt: &str, options: Vec<T>, default: Option<T>) -> Result<String>
 where
     T: std::fmt::Display + std::cmp::PartialEq + Clone,
 {
@@ -49,5 +51,5 @@ where
         }
     }
 
-    prompt.prompt().unwrap().to_string()
+    Ok(prompt.prompt()?.to_string())
 }
diff --git a/tests/cli_init.rs b/tests/cli_init.rs
index 3fad6cf9..b099cd10 100644
--- a/tests/cli_init.rs
+++ b/tests/cli_init.rs
@@ -106,16 +106,14 @@ async fn test_interactive_default_init() {
 
     let mut p = spawn(&format!("{cmd:?} init --dry-run"), Some(30_000)).unwrap();
 
+    let mut set_github_token = false;
     while let Ok(line) = p.read_line() {
         println!("{line}");
-        // if line.contains("Dry run, would have written") {
-        //     break;
-        // }
-        if line.contains("base url") {
-            let _ = p.send_line("https://api.bosun.ai");
-        } else {
-            let _ = p.send_line("");
+        if line.contains("Github token (optional") && !set_github_token {
+            let _ = p.send_line("env:GITHUB_TOKEN");
+            set_github_token = true;
         }
+        let _ = p.send_line("");
     }
 
     println!("{}", p.exp_eof().unwrap());