feat: Onboarding flow improvements (#265)
Makes some small improvements to the onboarding flow

- **Remove owner/repo name detection and make auto push to remote false by default**
- **Add a header for the LLM part**
- **Remove all panics when onboarding** (pattern sketched below)
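
As a reference for the panic-removal and auto-push changes, here is a minimal sketch of the pattern the diff applies across the onboarding modules. It assumes the `inquire` and `anyhow` crates already used by the repository; the function names and prompts are simplified stand-ins rather than the exact kwaak helpers (which wrap `inquire` in `src/onboarding/util.rs`).

```rust
use anyhow::Result;

// Before this commit, any prompt failure (closed stdin, ctrl-c) panicked via
// `.unwrap()`. After, each question function returns `anyhow::Result<()>` and
// propagates the error with `?`, so `onboarding::run` can report it instead.
fn command_questions() -> Result<()> {
    let test_command =
        inquire::Text::new("Test command (optional, <esc> to skip)").prompt_skippable()?;
    println!("test command: {test_command:?}");
    Ok(())
}

// The auto-push question is now asked only when a GitHub token was provided,
// and defaults to `false` either way.
fn auto_push_question(github_api_key: Option<&str>) -> Result<bool> {
    if github_api_key.is_none() {
        return Ok(false);
    }
    let push = inquire::Confirm::new("Push to git remote after changes? (requires github token)")
        .with_default(false)
        .prompt()?;
    Ok(push)
}

fn main() -> Result<()> {
    command_questions()?;
    let auto_push = auto_push_question(None)?;
    println!("auto_push_remote = {auto_push}");
    Ok(())
}
```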
timonv authored Feb 3, 2025
1 parent 7ff831a commit 9d31200
Showing 7 changed files with 83 additions and 93 deletions.
15 changes: 8 additions & 7 deletions src/onboarding/commands.rs
@@ -1,17 +1,16 @@
use anyhow::Result;
use serde_json::json;

use crate::onboarding::util::prompt_text;

pub fn command_questions(context: &mut tera::Context) {
pub fn command_questions(context: &mut tera::Context) -> Result<()> {
println!("\nKwaak agents can run tests and use code coverage when coding. Kwaak uses tests as an extra feedback moment for agents");

let test_command = prompt_text("Test command (optional, <esc> to skip)", None)
.prompt_skippable()
.unwrap();
let test_command =
prompt_text("Test command (optional, <esc> to skip)", None).prompt_skippable()?;

let coverage_command = prompt_text("Coverage command (optional, <esc> to skip)", None)
.prompt_skippable()
.unwrap();
let coverage_command =
prompt_text("Coverage command (optional, <esc> to skip)", None).prompt_skippable()?;

context.insert(
"commands",
@@ -20,4 +19,6 @@ pub fn command_questions(context: &mut tera::Context) {
"coverage": coverage_command,
}),
);

Ok(())
}
52 changes: 20 additions & 32 deletions src/onboarding/git.rs
@@ -1,53 +1,41 @@
use anyhow::Result;
use serde_json::json;

use crate::{
config::defaults::{default_main_branch, default_owner_and_repo},
onboarding::util::{prompt_api_key, prompt_text},
};

pub fn git_questions(context: &mut tera::Context) {
pub fn git_questions(context: &mut tera::Context) -> Result<()> {
let (default_owner, default_repository) = default_owner_and_repo().unzip();
let default_branch = default_main_branch();
let branch_input = prompt_text("Default git branch", Some(&default_branch))
.prompt()
.unwrap();

println!("\nWith a github token, Kwaak can create pull requests, search github code, and automatically push to the remote.");
let github_api_key = prompt_api_key(
"GitHub api key (optional, <esc> to skip)",
Some("env:GITHUB_TOKEN"),
)
.prompt_skippable()
.unwrap();

let auto_push_remote =
let branch_input = prompt_text("Default git branch", Some(&default_branch)).prompt()?;

println!("\nWith a github token, Kwaak can create pull requests, search github code, and automatically push to the remote. Kwaak will never push to the main branch.");

let github_api_key = prompt_api_key("Github token (optional, <esc> to skip)", None)
.with_placeholder("env:GITHUB_token")
.prompt_skippable()?;

let auto_push_remote = if github_api_key.is_some() {
inquire::Confirm::new("Push to git remote after changes? (requires github token)")
.with_default(github_api_key.is_some())
.prompt()
.unwrap();

let owner_input = prompt_text(
"Git owner (optional, <esc> to skip)",
default_owner.as_deref(),
)
.prompt_skippable()
.unwrap();
let repository_input = prompt_text(
"Git repository (optional, <esc> to skip)",
default_repository.as_deref(),
)
.prompt_skippable()
.unwrap();
.with_default(false)
.prompt()?
} else {
false
};

context.insert("github_api_key", &github_api_key);
context.insert(
"git",
&json!({
"owner": owner_input,
"repository": repository_input,
"owner": default_owner,
"repository": default_repository,
"main_branch": branch_input,
"auto_push_remote": auto_push_remote,

}),
);

Ok(())
}
72 changes: 36 additions & 36 deletions src/onboarding/llm.rs
@@ -1,3 +1,4 @@
use anyhow::Result;
use std::collections::HashMap;

use serde_json::json;
@@ -10,7 +11,8 @@ use crate::{

use super::util::{prompt_api_key, prompt_select};

pub async fn llm_questions(context: &mut tera::Context) {
pub async fn llm_questions(context: &mut tera::Context) -> Result<()> {
println!("\nKwaak supports multiple LLM providers and uses multiple models for various tasks. What providers would you like to use?");
let valid_llms = LLMConfiguration::VARIANTS
.iter()
.map(AsRef::as_ref) // Kinda weird that we need to do this
@@ -21,14 +23,13 @@ pub async fn llm_questions(context: &mut tera::Context) {
"What LLM would you like to use?",
valid_llms,
Some("OpenAI"),
)
.parse()
.unwrap();
)?
.parse()?;

match valid_llm {
LLMConfiguration::OpenAI { .. } => openai_questions(context),
LLMConfiguration::Ollama { .. } => ollama_questions(context),
LLMConfiguration::OpenRouter { .. } => open_router_questions(context).await,
LLMConfiguration::OpenAI { .. } => openai_questions(context)?,
LLMConfiguration::Ollama { .. } => ollama_questions(context)?,
LLMConfiguration::OpenRouter { .. } => open_router_questions(context).await?,
LLMConfiguration::FastEmbed { .. } => {
println!("{valid_llm} is not selectable yet, skipping configuration");
}
@@ -37,31 +38,32 @@ pub async fn llm_questions(context: &mut tera::Context) {
println!("{valid_llm} is not meant for production use, skipping configuration");
}
}

Ok(())
}

fn openai_questions(context: &mut tera::Context) {
fn openai_questions(context: &mut tera::Context) -> Result<()> {
let api_key = prompt_api_key(
"Where can we find your OpenAI api key? (https://platform.openai.com/api-keys)",
Some("env:OPENAI_API_KEY"),
)
.prompt()
.unwrap();
.prompt()?;
let indexing_model = prompt_select(
"Model used for fast operations (like indexing)",
OpenAIPromptModel::VARIANTS.to_vec(),
Some("gpt-4o-mini"),
);
)?;
let query_model = prompt_select(
"Model used for querying and code generation",
OpenAIPromptModel::VARIANTS.to_vec(),
Some("gpt-4o"),
);
)?;

let embedding_model = prompt_select(
"Model used for embeddings",
OpenAIEmbeddingModel::VARIANTS.to_vec(),
Some("text-embedding-3-large"),
);
)?;

context.insert("openai_api_key", &api_key);
context.insert(
@@ -82,6 +84,8 @@ fn openai_questions(context: &mut tera::Context) {
"base_url": None::<String>,
}),
);

Ok(())
}

async fn get_open_router_models() -> Option<Vec<HashMap<String, serde_json::Value>>> {
@@ -102,15 +106,14 @@ async fn get_open_router_models() -> Option<Vec<HashMap<String, serde_json::Value>>> {
response.json().await.ok()?;
models.get("data").map(Vec::to_owned)
}
async fn open_router_questions(context: &mut tera::Context) {
async fn open_router_questions(context: &mut tera::Context) -> Result<()> {
println!("\nOpenRouter allows you to use a variety of managed models via a single api. You can find models at https://openrouter.ai/models.");

let api_key = prompt_api_key(
"Where can we find your OpenRouter api key? (https://openrouter.ai/settings/keys)",
Some("env:OPEN_ROUTER_API_KEY"),
)
.prompt()
.unwrap();
.prompt()?;

let openrouter_models = get_open_router_models().await;

@@ -143,16 +146,14 @@ async fn open_router_questions(context: &mut tera::Context) {
)
.with_autocomplete(autocompletion.clone())
.with_validator(validator.clone())
.prompt()
.unwrap();
.prompt()?;
let query_model = prompt_text(
"Model used for querying and code generation",
Some("anthropic/claude-3.5-sonnet"),
)
.with_autocomplete(autocompletion.clone())
.with_validator(validator.clone())
.prompt()
.unwrap();
.prompt()?;

context.insert("open_router_api_key", &api_key);

@@ -167,28 +168,26 @@ async fn open_router_questions(context: &mut tera::Context) {
);

println!("\nOpenRouter does not support embeddings yet. Currently we suggest to use FastEmbed. If you want to use a different provider you can change it in your config later.");
fastembed_questions(context);
fastembed_questions(context)
}

fn ollama_questions(context: &mut tera::Context) {
fn ollama_questions(context: &mut tera::Context) -> Result<()> {
println!("Note that you need to have a running Ollama instance.");

let indexing_model = prompt_text(
"Model used for fast operations (like indexing). This model does not need to support tool calls.",
None

).prompt().unwrap();
).prompt()?;

let query_model = prompt_text(
"Model used for querying and code generation. This model needs to support tool calls.",
None,
)
.prompt()
.unwrap();
.prompt()?;

let embedding_model = prompt_text("Model used for embeddings, bge-m3 is a solid choice", None)
.prompt()
.unwrap();
let embedding_model =
prompt_text("Model used for embeddings, bge-m3 is a solid choice", None).prompt()?;

let vector_size = inquire::Text::new("Vector size for the embedding model")
.with_validator(|input: &str| match input.parse::<usize>() {
@@ -197,8 +196,7 @@ fn ollama_questions(context: &mut tera::Context) {
"Invalid number".into(),
)),
})
.prompt()
.unwrap();
.prompt()?;

let base_url = inquire::Text::new("Custom base url? (optional, <esc> to skip)")
.with_validator(|input: &str| match url::Url::parse(input) {
@@ -207,8 +205,7 @@ fn ollama_questions(context: &mut tera::Context) {
"Invalid URL".into(),
)),
})
.prompt_skippable()
.unwrap();
.prompt_skippable()?;

context.insert(
"llm",
@@ -228,18 +225,19 @@ fn ollama_questions(context: &mut tera::Context) {
"vector_size": vector_size,
}),
);

Ok(())
}

pub fn fastembed_questions(context: &mut tera::Context) {
pub fn fastembed_questions(context: &mut tera::Context) -> Result<()> {
println!("\nFastEmbed provides embeddings that are generated quickly locally. Unless you have a specific need for a different model, the default is a good choice.");

let embedding_model: FastembedModel = prompt_select(
"Embedding model",
FastembedModel::list_supported_models(),
Some(FastembedModel::default().to_string()),
)
.parse()
.unwrap();
)?
.parse()?;

context.insert(
"embed_llm",
@@ -249,6 +247,8 @@ pub fn fastembed_questions(context: &mut tera::Context) {
"base_url": None::<String>,
}),
);

Ok(())
}

#[derive(Clone)]
8 changes: 4 additions & 4 deletions src/onboarding/mod.rs
@@ -44,10 +44,10 @@ pub async fn run(file: Option<PathBuf>, dry_run: bool) -> Result<()> {
println!("We have a few questions to ask you to get started, you can always change these later in the `{}` file.", file.display());

let mut context = tera::Context::new();
project_questions(&mut context);
git_questions(&mut context);
llm_questions(&mut context).await;
command_questions(&mut context);
project_questions(&mut context)?;
git_questions(&mut context)?;
llm_questions(&mut context).await?;
command_questions(&mut context)?;

let config =
Templates::render("kwaak.toml", &context).context("Failed to render default config")?;
11 changes: 6 additions & 5 deletions src/onboarding/project.rs
@@ -1,15 +1,14 @@
use anyhow::Result;
use strum::IntoEnumIterator;
use swiftide::integrations::treesitter::SupportedLanguages;

use crate::config::defaults::default_project_name;

use super::util::{prompt_select, prompt_text};

pub fn project_questions(context: &mut tera::Context) {
pub fn project_questions(context: &mut tera::Context) -> Result<()> {
let project_name = default_project_name();
let project_name_input = prompt_text("Project name", Some(&project_name))
.prompt()
.unwrap();
let project_name_input = prompt_text("Project name", Some(&project_name)).prompt()?;
context.insert("project_name", &project_name_input);

// Get user inputs with defaults
@@ -18,9 +17,11 @@ pub fn project_questions(context: &mut tera::Context) {
.map(|l| l.to_string())
.collect::<Vec<_>>();

let language_input = prompt_select("Programming language", options.clone(), detected);
let language_input = prompt_select("Programming language", options.clone(), detected)?;

context.insert("language", &language_input);

Ok(())
}

fn naive_lang_detect() -> Option<String> {
6 changes: 4 additions & 2 deletions src/onboarding/util.rs
@@ -1,3 +1,5 @@
use anyhow::Result;

pub fn prompt_text<'a>(prompt: &'a str, default: Option<&'a str>) -> inquire::Text<'a> {
let mut prompt = inquire::Text::new(prompt);

@@ -27,7 +29,7 @@ pub fn prompt_api_key<'a>(prompt: &'a str, default: Option<&'a str>) -> inquire:
}

#[allow(clippy::needless_pass_by_value)]
pub fn prompt_select<T>(prompt: &str, options: Vec<T>, default: Option<T>) -> String
pub fn prompt_select<T>(prompt: &str, options: Vec<T>, default: Option<T>) -> Result<String>
where
T: std::fmt::Display + std::cmp::PartialEq + Clone,
{
@@ -49,5 +51,5 @@ where
}
}

prompt.prompt().unwrap().to_string()
Ok(prompt.prompt()?.to_string())
}