Skip to content

Commit

Permalink
improve output format
Browse files Browse the repository at this point in the history
  • Loading branch information
ricopinazo committed May 20, 2024
1 parent 0d475a3 commit 38043be
Show file tree
Hide file tree
Showing 5 changed files with 144 additions and 39 deletions.
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@ classifiers = [
"Operating System :: Unix",
]
dependencies = [
"typer[all]==0.9.0",
"typer-slim==0.12.3",
"rich==13.7.1",
"langchain==0.1.17",
"langchain-openai==0.1.6",
"ollama==0.1.9",
Expand Down
22 changes: 22 additions & 0 deletions src/comai/animations.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,10 @@
from contextlib import contextmanager
from typing import Generator, Iterator
from rich import print
from prompt_toolkit import prompt
from prompt_toolkit.styles import Style

from comai.prompt import prompt_str

LEFT = "\033[D"
CLEAR_LINE = "\033[K"
Expand Down Expand Up @@ -50,6 +54,24 @@ def print_command_token(chunk: str):
print(f"[{COMMAND_COLOR}]{chunk}", end="", flush=True)


def print_command_prompt(command: str):
    """Replace the streamed command line with an editable prompt.

    Clears the current terminal line (where the command was animated) and
    shows the command pre-filled after the answer marker so the user can
    edit it before running.

    Args:
        command: The command suggested by the model, used as the editable
            default text.

    Returns:
        The (possibly edited) command string entered by the user.
    """
    # Overwrite the line where the command was streamed token by token.
    sys.stdout.write(f"\r{CLEAR_LINE}")
    style = Style.from_dict(
        {
            "": "ansicyan",  # user input (the editable command text)
            "mark": "ansimagenta",  # the answer-prompt marker
        }
    )
    message = [
        ("class:mark", ANSWER_PROMPT),
    ]
    return prompt(message, default=command, style=style)


def hide_cursor() -> None:
if stdout_is_tty:
sys.stdout.write("\033[?25l")
Expand Down
94 changes: 61 additions & 33 deletions src/comai/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,19 @@
from typing import List, Optional, Iterator
from typing_extensions import Annotated
from langchain_community.llms.ollama import OllamaEndpointNotFoundError
from urllib3.exceptions import NewConnectionError
from requests.exceptions import ConnectionError

from comai import __version__
from comai.chain import StreamStart, Token, FinalCommand, query_command
from comai.ollama import get_ollama_model_names
from comai.settings import load_settings, write_settings, Settings
from comai.prompt import prompt_bool, prompt_options
from comai.settings import (
InvalidSettingsFileException,
load_settings,
write_settings,
Settings,
)
from comai.context import get_context
from comai.menu import get_option_from_menu, MenuOption
from comai.animations import (
Expand All @@ -19,9 +27,11 @@
show_cursor,
hide_cursor,
start_printing_command,
print_command_prompt,
)

app = typer.Typer()
typer.core.rich = None  # this is to avoid using rich for the help panel
app = typer.Typer(pretty_exceptions_enable=False, add_completion=False)


def version_callback(value: bool):
Expand Down Expand Up @@ -50,23 +60,11 @@ def settings_callback(value: bool):
default_model = ollama_models[0]
else:
default_model = "llama3"
model = click.prompt(
"Ollama model",
type=click.Choice(ollama_models),
default=default_model,
show_default=True,
show_choices=True,
)
verbose = click.prompt(
"Verbose mode",
type=click.BOOL,
default="yes" if settings.verbose else "no",
show_default=True,
show_choices=True,
model = prompt_options(
"Which model do you want to use?", ollama_models, default_model
)
settings.provider = "ollama"
settings.model = model
settings.verbose = verbose
write_settings(settings)
raise typer.Exit()

Expand Down Expand Up @@ -95,35 +93,65 @@ def main_normal_flow(instructions: List[str], settings: Settings):
if final_command is None:
raise Exception("failed to fetch command")

match get_option_from_menu(settings):
case MenuOption.run:
os.system(final_command)
case MenuOption.cancel:
pass
user_command = print_command_prompt(final_command)
os.system(user_command)

# match get_option_from_menu(settings):
# case MenuOption.run:
# os.system(final_command)
# case MenuOption.cancel:
# pass

@app.command()

@app.command(help="Translates natural language instructions into commands")
def main(
instructions: List[str],
version: Annotated[
Optional[bool], typer.Option("--version", callback=version_callback)
] = None,
config: Annotated[
Optional[bool], typer.Option("--config", callback=settings_callback)
Optional[bool],
typer.Option(
"--config",
callback=settings_callback,
help="Starts an interactive menu to select your configuration", # assisted isntead?
),
] = None,
show_config: Annotated[
Optional[bool], typer.Option("--show-config", callback=show_settings_callback)
Optional[bool],
typer.Option(
"--show-config",
callback=show_settings_callback,
help="Show the current configuration",
),
] = None,
version: Annotated[
Optional[bool],
typer.Option(
"--version", callback=version_callback, help="Show the current version"
),
] = None,
):
settings = load_settings()

try:
settings = load_settings()
except InvalidSettingsFileException as e:
message = f"Your settings file at {e.settings_path} is incorrect. Please fix it or start from scratch with comai --config"
sys.stderr.write(message + "\n")
exit(1)
except Exception as e:
raise e

try:
main_normal_flow(instructions, settings)
except OllamaEndpointNotFoundError as e:
sys.stderr.write(
f"Model '{settings.model}' not found in the ollama service. Please download it with 'ollama pull {settings.model}' or select a different model with 'comai --config'"
)
typer.Exit(1)
except OllamaEndpointNotFoundError:
message = f"Model '{settings.model}' not found in the ollama service. Please download it with 'ollama pull {settings.model}' or select a different model with 'comai --config'"
sys.stderr.write(message + "\n")
typer.Exit(2)
except ConnectionError as e:
if settings.provider == "ollama":
message = f"Ollama service is not running. Please install it and run it (https://ollama.com/download) or select a different provider with 'comai --config'"
sys.stderr.write(message + "\n")
typer.Exit(3)
else:
raise e
except Exception as e:
raise e
finally:
Expand Down
46 changes: 46 additions & 0 deletions src/comai/prompt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
from typing import Literal
from simple_term_menu import TerminalMenu
from rich import print
from prompt_toolkit import prompt
from prompt_toolkit.styles import Style


def prompt_str(question: str, default: str) -> str:
    """Ask the user a free-text question with a pre-filled default answer.

    Args:
        question: The question to display (leading/trailing whitespace is
            stripped).
        default: Text pre-filled in the input field; the user may edit it.

    Returns:
        The string entered by the user.
    """
    style = Style.from_dict(
        {
            "": "ansicyan",  # user input (default text)
            "mark": "ansicyan",
            "question": "ansiwhite",
        }
    )
    message = [
        ("class:mark", "? "),
        ("class:question", f"{question.strip()} "),
    ]
    return prompt(message, default=default, style=style)


def prompt_options(question: str, options: list[str], default: str) -> str:
    """Ask the user to pick one of *options* via an interactive terminal menu.

    The cursor starts on *default* when it is one of *options*, so pressing
    <Enter> keeps the suggested choice. If the menu is dismissed without a
    selection (TerminalMenu.show() returns None, e.g. on <Esc>), *default*
    is returned instead of crashing on a None index.

    Args:
        question: The question shown as the menu title.
        options: The selectable entries.
        default: The entry pre-selected by the cursor and the fallback on
            dismissal.

    Returns:
        The chosen option (or *default* when the menu was dismissed).
    """
    # Pre-select the default entry; fall back to the first one if the
    # caller passed a default that is not among the options.
    initial_index = options.index(default) if default in options else 0
    terminal_menu = TerminalMenu(
        options,
        title=f"? {question}",
        menu_cursor="• ",
        menu_cursor_style=("bg_black", "fg_green"),
        menu_highlight_style=("bg_black", "fg_green"),
        cursor_index=initial_index,
    )
    index = terminal_menu.show()
    answer = options[index] if index is not None else default
    # Echo the final choice so it stays visible after the menu closes.
    print(f"[cyan]?[/cyan] {question} [cyan]{answer}[/cyan]")
    return answer


def prompt_bool(question: str, default: bool) -> bool:
    """Ask the user a yes/no question through the interactive options menu.

    Args:
        question: The question to display.
        default: Pre-selected answer (True -> "Yes", False -> "No").

    Returns:
        True if the user chose "Yes", False if "No".

    Raises:
        Exception: If the menu somehow returns an unexpected value.
    """
    default_option = "Yes" if default else "No"
    answer = prompt_options(question, ["Yes", "No"], default_option)
    if answer == "Yes":
        return True
    if answer == "No":
        return False
    raise Exception(f"unexpected input {answer}")
18 changes: 13 additions & 5 deletions src/comai/settings.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
import sys
from typing import Literal
import typer
from pydantic import BaseModel
Expand All @@ -10,20 +11,27 @@

class Settings(BaseModel):
    """User-configurable settings persisted to the settings file on disk."""

    # Which LLM backend to use; only these two providers are supported.
    provider: Literal["ollama", "openai"]
    # TODO: improve this, should be typed per provider, although possible models can be queried at runtime
    model: str = "llama3"
    # Whether to print extra output during normal operation.
    verbose: bool = True


DEFAULT_SETTINGS: Settings = Settings(provider="ollama")


class InvalidSettingsFileException(Exception):
    """Raised when the settings file exists but cannot be parsed or validated.

    Derives from Exception (not BaseException) so that generic
    `except Exception` boundaries can catch it; BaseException is reserved
    for interpreter-level signals such as KeyboardInterrupt/SystemExit.

    Attributes:
        settings_path: Path of the offending settings file, kept so callers
            can point the user at the broken file.
    """

    def __init__(self, settings_path: str):
        self.settings_path = settings_path
        # Provide a message so str(e)/logs are informative.
        super().__init__(f"invalid settings file: {settings_path}")


def load_settings() -> Settings:
    """Load settings from the settings file on disk.

    Returns:
        The validated Settings, or DEFAULT_SETTINGS when no settings file
        exists yet.

    Raises:
        InvalidSettingsFileException: If the file exists but its content
            does not parse/validate as a Settings model.
    """
    try:
        with open(settings_path, "r") as file:
            raw = file.read()
    except FileNotFoundError:
        # First run: no settings file has been written yet.
        return DEFAULT_SETTINGS
    try:
        return Settings.model_validate_json(raw)
    except Exception:
        # The file is present but malformed; let the caller tell the user.
        raise InvalidSettingsFileException(settings_path=settings_path)


def write_settings(settings: Settings):
Expand Down

0 comments on commit 38043be

Please sign in to comment.