diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml
index e20f346..47b22e8 100644
--- a/.github/workflows/release.yml
+++ b/.github/workflows/release.yml
@@ -1,25 +1,23 @@
name: Release new version
-on:
+on:
workflow_dispatch:
inputs:
type:
- description: 'Release type'
+ description: "Release type"
required: true
- default: 'preview'
+ default: "preview"
type: choice
options:
- - major
- - minor
- - fix
- - preview
- - release
+ - fix
+ - minor
+ - minor,preview
+ - major,preview
+ - release
jobs:
test:
uses: ./.github/workflows/tests.yml
- secrets:
- OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
release:
needs: test
runs-on: ubuntu-latest
diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml
index 247efc4..28c89bb 100644
--- a/.github/workflows/tests.yml
+++ b/.github/workflows/tests.yml
@@ -1,11 +1,8 @@
name: tests
-on:
+on:
push:
workflow_call:
- secrets:
- OPENAI_API_KEY:
- required: true
jobs:
test:
@@ -22,24 +19,10 @@ jobs:
with:
python-version: "3.x"
- name: Install dependencies
- run: python -m pip install -e .[test]
- - name: Set environment variable using secret
- run: echo "OPENAI_API_KEY=${{ secrets.OPENAI_API_KEY }}" >> $GITHUB_ENV
- - name: Run test suite
- run: pytest -v
- code-checks:
- runs-on: ubuntu-latest
- timeout-minutes: 10
- steps:
- - name: Check out repository code
- uses: actions/checkout@v3
- - name: Setup Python
- uses: actions/setup-python@v4
- with:
- python-version: "3.x"
- - name: Install black
- run: python -m pip install black=="23.*" mypy=="1.*"
+ run: python -m pip install -e .[dev]
- name: Check style
- run: black --check .
+ run: black --check --diff .
      - name: Check types
run: mypy .
+ - name: Run test suite
+ run: pytest -v
diff --git a/.gitignore b/.gitignore
index f2058f0..b557946 100644
--- a/.gitignore
+++ b/.gitignore
@@ -3,4 +3,6 @@ __pycache__
.python-version
dist
.env
-.DS_Store
\ No newline at end of file
+.DS_Store
+.ipynb_checkpoints
+experiments
diff --git a/README.md b/README.md
index 973b7b9..3573bcc 100644
--- a/README.md
+++ b/README.md
@@ -1,127 +1,106 @@

-
- **The AI powered terminal assistant**
-
- [](https://github.com/ricopinazo/comai/actions/workflows/tests.yml)
- [](https://github.com/ricopinazo/comai/releases)
- [](https://pypi.org/project/comai/)
- [](https://github.com/ricopinazo/comai/issues)
- [](https://pypi.org/project/comai/)
- [](./LICENSE)
- [](https://github.com/psf/black)
- [](http://mypy-lang.org/)
+
+**The AI powered terminal assistant**
+
+[](https://github.com/ricopinazo/comai/actions/workflows/tests.yml)
+[](https://github.com/ricopinazo/comai/releases)
+[](https://pypi.org/project/comai/)
+[](https://github.com/ricopinazo/comai/issues)
+[](https://pypi.org/project/comai/)
+[](./LICENSE)
+[](https://github.com/psf/black)
+[](http://mypy-lang.org/)
+
## What is comai? 🎯
-`comai` is an open source terminal assistant powered by OpenAI API that enables you to interact with your command line interface using natural language instructions. It simplifies your workflow by converting natural language queries into executable commands. No more memorizing complex syntax. Just chat with `comai` using plain English!
+`comai` is an open source CLI utility that translates natural language instructions into executable commands.
-

+
## Installation 🚀
-Getting `comai` up and running is a breeze. You can simply use [`pip`](https://pip.pypa.io/en/stable/) to install the latest version:
+`comai` is available as a Python package. We recommend using [`pipx`](https://pypa.github.io/pipx/) to install it:
```shell
-pip install comai
+pipx install comai
```
-However, if you usually work with python environments, it is recommended to use [`pipx`](https://pypa.github.io/pipx/) instead:
+By default, `comai` is set up to use [ollama](https://ollama.com) under the hood, which lets you host any popular open source LLM locally. If you are happy with this, make sure ollama is installed and running. You can find the install instructions [here](https://ollama.com/download).
+
+Once installed, make sure to download the `llama3` model, since `comai` has been optimised for it:
```shell
-pipx install comai
+ollama pull llama3
```
-The first time you run it, it'll ask you for an OpenAI API key. You can create a developer account [here](https://platform.openai.com/overview). Once in your account, go to `API Keys` section and `Create new secret key`. We recommend setting a usage limit under `Billing`/`Usage limits`.
+Otherwise, you can set up any other model available in the ollama service via:
+
+```shell
+comai --config
+```
> **_NOTE:_** `comai` uses the environment variable `TERM_SESSION_ID` to maintain context between calls, so you don't need to repeat yourself when giving it instructions. You can check whether it is available in your terminal by inspecting the output of `echo $TERM_SESSION_ID`, which should return some kind of UUID. If the output is empty, simply add the following to your `.zshrc`/`.bashrc` file:
+>
> ```shell
> export TERM_SESSION_ID=$(uuidgen)
> ```
-## Usage Examples 🎉
+## Usage examples 🎉
Using `comai` is straightforward. Simply invoke the `comai` command followed by your natural language instruction. `comai` will provide you with the corresponding executable command, which you can execute by pressing Enter or ignore by pressing any other key.
Let's dive into some exciting examples of how you can harness the power of `comai`:
-1. Manage your system like a pro:
-```shell
-$ comai print my private ip address
-❯ ifconfig | grep "inet " | grep -v 127.0.0.1 | awk '{print $2}'
-192.168.0.2
+1. Access network details:
-$ comai and my public one
-❯ curl ifconfig.me
+```shell
+$ comai print my public ip address
+❯ curl -s4 ifconfig.co
92.234.58.146
```
-2. Master the intricacies of `git`:
+2. Manage `git` like a pro:
+
```shell
-$ comai squash the last 3 commits into a single commit
-❯ git rebase -i HEAD~3
+$ comai rename the current branch to awesome-branch
+❯ git branch -m $(git rev-parse --abbrev-ref HEAD) awesome-branch
$ comai show me all the branches having commit c4c0d2d in common
-❯ git branch --contains c4c0d2d
- chat-api
- configparser
-* main
+❯ git branch -a --contains c4c0d2d
+ main
+ fix/terrible-bug
+* awesome-branch
```
-3. Check the weather forecast for your location:
-```shell
-$ comai show me the weather forecast
-❯ curl wttr.in
-```
+3. Find the annoying process using port 8080:
-4. Find the annoying process using the port 8080:
```shell
$ comai show me the process using the port 8080
❯ lsof -i :8080
COMMAND PID USER FD TYPE DEVICE SIZE/OFF NODE NAME
node 36350 pedrorico 18u IPv4 0xe0d28ea918e376b 0t0 TCP *:http-alt (LISTEN)
-
-$ comai show me only the PID
-❯ lsof -t -i :8080
-36350
-
-$ comai kill it
-❯ kill $(lsof -t -i :8080)
```
-5. Swiftly get rid of all your docker containers:
+4. Get rid of all your docker containers:
+
```shell
$ comai stop and remove all running docker containers
❯ docker stop $(docker ps -aq) && docker rm $(docker ps -aq)
```
-These are just a few examples of how `comai` can help you harness the power of the command line and provide you with useful and entertaining commands. Feel free to explore and experiment with the commands generated by `comai` to discover more exciting possibilities!
-
## Contributions welcome! 🤝
If you're interested in joining the development of new features for `comai`, here's all you need to get started:
1. Clone the [repository](https://github.com/ricopinazo/comai) and navigate to the root folder.
2. Install the package in editable mode by running `pip install -e .`.
-3. Run the tests using `pytest`. Make sure you have the `OPENAI_API_KEY` environment variable set up with your OpenAI API key. Alternatively, you can create a file named `.env` and define the variable there.
-
-
-This project utilizes black for code formatting. To ensure your changes adhere to this format, simply follow these steps:
-
-```shell
-pip install black
-black .
-```
-
-For users of VS Code, you can configure the following options after installing `black`:
-
-```json
-"editor.formatOnSave": true,
-"python.formatting.provider": "black"
-```
+3. Run the tests using `pytest`.
+4. Format your code using [black](https://github.com/psf/black) before submitting any change.
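+
+For example, a typical dev setup could look like this (run from the repo root):
+
+```shell
+pip install -e ".[dev]"   # installs comai together with pytest, black and mypy
+pytest -v                 # run the test suite
+black . && mypy .         # format and type-check, as the CI does
+```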
## License 📜
diff --git a/demo.gif b/demo.gif
index da28ed0..ea49668 100644
Binary files a/demo.gif and b/demo.gif differ
diff --git a/pyproject.toml b/pyproject.toml
index ff8c1b3..917d6b8 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -5,12 +5,10 @@ build-backend = "hatchling.build"
[project]
name = "comai"
dynamic = ["version"]
-authors = [
- { name="Pedro Rico", email="ricopinazo@gmail.com" },
-]
+authors = [{ name = "Pedro Rico", email = "ricopinazo@gmail.com" }]
description = "AI powered console assistant"
readme = "README.md"
-license = {file = "LICENSE"}
+license = { file = "LICENSE" }
requires-python = ">=3.7"
classifiers = [
"Programming Language :: Python :: 3",
@@ -20,9 +18,13 @@ classifiers = [
"Operating System :: Unix",
]
dependencies = [
- "typer[all]==0.9.0",
- "openai==0.27.5",
- "cryptography==40.0.2",
+ "typer-slim==0.12.3",
+ "rich==13.7.1",
+ "prompt-toolkit==3.0.43",
+ "simple-term-menu==1.6.4",
+ "langchain==0.1.17",
+ "langchain-openai==0.1.6",
+ "ollama==0.1.9",
]
[project.urls]
@@ -34,10 +36,12 @@ issues = "https://github.com/ricopinazo/comai/issues"
comai = "comai.cli:app"
[project.optional-dependencies]
-test = [
+dev = [
"pytest",
"hatchling",
- "python-dotenv"
+ "types-requests==2.31.0.20240406",
+ "mypy==1.9.0",
+ "black==24.4.0",
]
[tool.hatch.version]
diff --git a/src/comai/animations.py b/src/comai/animations.py
index 6c0f8c9..3e4ed8f 100644
--- a/src/comai/animations.py
+++ b/src/comai/animations.py
@@ -3,6 +3,10 @@
from contextlib import contextmanager
from typing import Generator, Iterator
from rich import print
+import prompt_toolkit
+from prompt_toolkit.styles import Style
+
+from comai.prompt import prompt_str
LEFT = "\033[D"
CLEAR_LINE = "\033[K"
@@ -42,12 +46,27 @@ def query_animation() -> Generator[None, None, None]:
t.join()
-def print_answer(command_chunks: Iterator[str]):
+def start_printing_command():
print(f"[{ANSWER_PROMPT_COLOR}]{ANSWER_PROMPT}", end="", flush=True)
- first_chunk = next(command_chunks)
- print(f"[{COMMAND_COLOR}]{first_chunk}", end="", flush=True)
- for chunk in command_chunks:
- print(f"[{COMMAND_COLOR}]{chunk}", end="", flush=True)
+
+
+def print_command_token(chunk: str):
+ print(f"[{COMMAND_COLOR}]{chunk}", end="", flush=True)
+
+
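+# Replace the streamed answer line with an editable prompt pre-filled with the
+# generated command, so the user can tweak it before pressing Enter to run it.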
+def print_command_prompt(command: str):
+ sys.stdout.write(f"\r{CLEAR_LINE}")
+ style = Style.from_dict(
+ {
+ # User input (default text)
+ "": "ansicyan",
+ "mark": "ansimagenta",
+ }
+ )
+ message = [
+ ("class:mark", ANSWER_PROMPT),
+ ]
+ return prompt_toolkit.prompt(message, default="%s" % command, style=style) # type: ignore
def hide_cursor() -> None:
diff --git a/src/comai/chain.py b/src/comai/chain.py
new file mode 100644
index 0000000..43b59bd
--- /dev/null
+++ b/src/comai/chain.py
@@ -0,0 +1,160 @@
+import itertools
+import logging
+import re
+import os
+import textwrap
+import uuid
+from dataclasses import dataclass
+from typing import Iterable, Iterator, List
+from pydantic import BaseModel, SecretStr
+
+from langchain.globals import set_debug, set_verbose
+from langchain_community.chat_message_histories import SQLChatMessageHistory
+from langchain_community.chat_models import ChatOllama
+from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage
+from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
+from langchain_core.runnables import RunnableGenerator
+from langchain_core.runnables.base import Runnable
+from langchain_core.runnables.history import (
+ MessagesOrDictWithMessages,
+ RunnableWithMessageHistory,
+)
+from langchain_core.runnables.utils import ConfigurableFieldSpec
+from langchain_core.tools import tool
+from langchain_openai import ChatOpenAI
+
+from comai.context import Context
+from comai.history import load_history
+from comai.settings import Settings
+
+logging.getLogger().setLevel(logging.CRITICAL)
+
+
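+# Build the chat prompt: a system message that pins the shell/OS context and the
+# strict "COMMAND ... END" output format, the per-session history, and the user question.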
+def create_prompt(context: Context):
+ system_message = f"""
+ You are a CLI tool called comai with access to a {context.shell} shell on {context.system}. Your goal is to translate the instruction from the user into a command.
+
+ ALWAYS use the following format, with no additional comments, explanation, or notes before or after AT ALL:
+
+ COMMAND END
+
+ Example:
+
+ User:
+ show files
+
+ Your answer:
+ COMMAND ls END
+ """
+
+ return ChatPromptTemplate.from_messages(
+ [
+ ("system", textwrap.dedent(system_message).strip()),
+ MessagesPlaceholder(variable_name="history"),
+ ("human", "{question}"),
+ ]
+ )
+
+
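+# Incrementally extract the command from the streamed chunks: accumulate the raw
+# text and yield only the characters that newly appear between COMMAND and END.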
+def parse_command(tokens: Iterable[AIMessageChunk]) -> Iterator[str]:
+ input = ""
+ output: str = ""
+ for token in tokens:
+ if type(token.content) == str:
+ input += token.content
+ pattern = r"COMMAND(.*?)(END|$)"
+ match = re.search(pattern, input)
+ if match:
+ updated_output = match.group(1).strip()
+ if len(updated_output) > len(output):
+ yield updated_output[len(output) :]
+ output = updated_output
+
+
+parse_command_generator = RunnableGenerator(parse_command)
+
+
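+# Wrap the chain with per-session message history so follow-up instructions keep
+# the context of previous ones (the history backend is resolved in comai.history).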
+def attach_history(
+ runnable: Runnable[
+ MessagesOrDictWithMessages,
+ str | BaseMessage | MessagesOrDictWithMessages,
+ ]
+):
+ return RunnableWithMessageHistory(
+ runnable,
+ lambda session_id: load_history(session_id),
+ input_messages_key="question",
+ history_messages_key="history",
+ )
+
+
+class OpenaiSecrets(BaseModel):
+ default_key: SecretStr | None
+ comai_key: SecretStr | None
+
+
+def extract_secret(key: str | None):
+ if key is None:
+ return None
+ else:
+ return SecretStr(key)
+
+
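+# Assemble prompt | model for the configured provider and return a closure that
+# streams the parsed command tokens for a question, scoped to the session id.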
+def create_chain_stream(settings: Settings, context: Context):
+ prompt = create_prompt(context)
+ if settings.provider == "ollama":
+ model = ChatOllama(model=settings.model, temperature=0)
+ elif settings.provider == "openai":
+ default_key = extract_secret(os.environ.get("OPENAI_API_KEY"))
+ comai_key = extract_secret(os.environ.get("COMAI_OPENAI_API_KEY"))
+ api_key = comai_key if comai_key is not None else default_key
+ model = ChatOpenAI(model=settings.model, temperature=0, api_key=api_key) # type: ignore
+ base_chain = prompt | model
+
+ if context.session_id is not None:
+ session_id = context.session_id
+ else:
+ session_id = str(uuid.uuid4()) # FIXME: should just not use history at all
+
+    chain_with_history = attach_history(base_chain)
+ runnable = chain_with_history | parse_command_generator
+
+ def stream(input: dict[str, str]):
+ return runnable.stream(
+ input=input, config={"configurable": {"session_id": session_id}}
+ )
+
+ return stream
+
+
+# TODO: move this to different file
+@dataclass
+class StreamStart:
+ pass
+
+
+@dataclass
+class Token:
+ content: str
+
+
+@dataclass
+class FinalCommand:
+ command: str
+
+
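+# Entry point used by the CLI: emits a StreamStart marker, then one Token per
+# streamed chunk, and finally the complete buffered command as FinalCommand.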
+def query_command(
+ query: str, settings: Settings, context: Context
+) -> Iterator[StreamStart | Token | FinalCommand]:
+ stream = create_chain_stream(settings, context)
+ output = stream({"question": query})
+
+ started = False
+ buffer = ""
+ for token in output:
+ if not started:
+ started = True
+ yield StreamStart()
+ yield Token(token)
+ buffer += token
+ yield FinalCommand(buffer)
diff --git a/src/comai/cli.py b/src/comai/cli.py
index 58fe3ef..f165017 100755
--- a/src/comai/cli.py
+++ b/src/comai/cli.py
@@ -1,19 +1,39 @@
import os
+import sys
import typer
+import click
import itertools
-from typing import List, Optional, Iterator
+from typing import List, Optional, Iterator, Literal
from typing_extensions import Annotated
-
-from . import config, context, translation, __version__
-from .menu import get_option_from_menu, MenuOption
-from .animations import (
+from langchain_community.llms.ollama import OllamaEndpointNotFoundError
+from urllib3.exceptions import NewConnectionError
+from requests.exceptions import ConnectionError
+
+from comai import __version__
+from comai.chain import StreamStart, Token, FinalCommand, query_command
+from comai.ollama import get_ollama_model_names
+from comai.openai import get_openai_model_names
+from comai.prompt import prompt_bool, prompt_options
+from comai.settings import (
+ InvalidSettingsFileException,
+ load_settings,
+ write_settings,
+ Settings,
+)
+from comai.context import get_context
+from comai.animations import (
+ print_command_token,
query_animation,
- print_answer,
show_cursor,
hide_cursor,
+ start_printing_command,
+ print_command_prompt,
)
-app = typer.Typer()
+# this is to avoid using rich for the help panel
+typer.core.rich = None # type: ignore
+
+app = typer.Typer(pretty_exceptions_enable=False, add_completion=False)
def version_callback(value: bool):
@@ -22,62 +42,123 @@ def version_callback(value: bool):
raise typer.Exit()
-def save_command(command_chunks, command_backup: list) -> Iterator[str]:
- for chunk in command_chunks:
- command_backup.append(chunk)
- yield chunk
+def show_settings_callback(value: bool):
+ if value:
+ settings = load_settings()
+ print("Current settings:")
+ print(settings.model_dump_json(indent=2))
+ raise typer.Exit()
+
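+# --config flow: pick a provider, list the models it currently exposes, preselect
+# a sensible default and persist the choice to the settings file.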
+def settings_callback(value: bool):
+ if value:
+ settings = load_settings()
-def wait_for_first_chunk(iterator: Iterator[str]):
- iter1, iter2 = itertools.tee(iterator)
- _ = next(iter1)
- return iter2
+ provider = prompt_options(
+ "Which provider do you want to use?",
+ ["ollama", "openai"],
+ settings.provider,
+ )
+ if provider == "ollama":
+ models = get_ollama_model_names()
+ elif provider == "openai":
+ models = get_openai_model_names()
+ else:
+ raise Exception(f"Got unknown provider option: {provider}")
+
+ if settings.model in models:
+ default_model = settings.model
+ elif "llama3" in models:
+ default_model = "llama3"
+ elif "gpt-3.5-turbo" in models:
+ default_model = "gpt-3.5-turbo"
+ elif len(models) > 0:
+ default_model = models[0]
+ else:
+ raise Exception("No models available for the selected provider")
+ model = prompt_options("Which model do you want to use?", models, default_model)
+        updated_settings = Settings(provider=provider, model=model)
+ write_settings(updated_settings)
+ raise typer.Exit()
-def main_normal_flow(instructions: List[str]):
+def main_normal_flow(instructions: List[str], settings: Settings):
+ final_command: str | None = None
input_text = " ".join(instructions)
- api_key = config.load_api_key()
- if not api_key:
- api_key = typer.prompt("Please enter your OpenAI API key")
- assert len(api_key) > 0
- if not translation.validate_api_key(api_key):
- print("API key not valid")
- exit(1)
- config.save_api_key(api_key)
-
hide_cursor()
- command_chunks: Iterator[str] = iter(())
- command_backup: List[str] = []
+ context = get_context()
+ output = query_command(input_text, settings, context)
+
with query_animation():
- ctx = context.get_context()
- history_path = config.get_history_path()
- command_chunks = translation.translate_with_history(
- input_text, history_path, ctx, api_key
- )
- command_chunks = save_command(command_chunks, command_backup)
- command_chunks = wait_for_first_chunk(command_chunks)
+ stream_start = next(output)
+ assert type(stream_start) == StreamStart
+
+ start_printing_command()
+ for chunk in output:
+ match chunk:
+ case Token(token):
+ print_command_token(token)
+ case FinalCommand(command):
+ final_command = command
- print_answer(command_chunks)
- command: str = "".join(command_backup)
+ if final_command is None:
+ raise Exception("failed to fetch command")
- match get_option_from_menu():
- case MenuOption.run:
- os.system(command)
- case MenuOption.cancel:
- pass
+ user_command = print_command_prompt(final_command)
+ os.system(user_command)
-@app.command()
+@app.command(help="Translates natural language instructions into commands")
def main(
instructions: List[str],
+ config: Annotated[
+ Optional[bool],
+ typer.Option(
+ "--config",
+ callback=settings_callback,
+            help="Starts an interactive menu to select your configuration",  # assisted instead?
+ ),
+ ] = None,
+ show_config: Annotated[
+ Optional[bool],
+ typer.Option(
+ "--show-config",
+ callback=show_settings_callback,
+ help="Show the current configuration",
+ ),
+ ] = None,
version: Annotated[
- Optional[bool], typer.Option("--version", callback=version_callback)
+ Optional[bool],
+ typer.Option(
+ "--version", callback=version_callback, help="Show the current version"
+ ),
] = None,
):
+
+ try:
+ settings = load_settings()
+ except InvalidSettingsFileException as e:
+ message = f"Your settings file at {e.settings_path} is incorrect. Please fix it or start from scratch with comai --config"
+ sys.stderr.write(message + "\n")
+ exit(1)
+ except Exception as e:
+ raise e
+
try:
- main_normal_flow(instructions)
+ main_normal_flow(instructions, settings)
+ except OllamaEndpointNotFoundError:
+ message = f"Model '{settings.model}' not found in the ollama service. Please download it with 'ollama pull {settings.model}' or select a different model with 'comai --config'"
+ sys.stderr.write(message + "\n")
+        raise typer.Exit(2)
+ except ConnectionError as e:
+ if settings.provider == "ollama":
+            message = "Ollama service is not running. Please install and start it (https://ollama.com/download) or select a different provider with 'comai --config'"
+ sys.stderr.write(message + "\n")
+            raise typer.Exit(3)
+ else:
+ raise e
except Exception as e:
raise e
finally:
diff --git a/src/comai/config.py b/src/comai/config.py
deleted file mode 100644
index 3d15686..0000000
--- a/src/comai/config.py
+++ /dev/null
@@ -1,57 +0,0 @@
-import os
-import configparser
-from pathlib import Path
-import typer
-import tempfile
-from typing import Optional
-from cryptography.fernet import Fernet
-
-CONTEXT_SIZE = 20
-APP_NAME = "comai"
-config_dir = typer.get_app_dir(APP_NAME, force_posix=True)
-key_path = os.path.join(config_dir, "config.ini")
-temp_dir = tempfile.gettempdir()
-session_id = os.getenv("TERM_SESSION_ID")
-log_path: Optional[os.PathLike] = None
-if session_id:
- try:
- log_path = Path(os.path.join(temp_dir, session_id))
- except Exception:
- pass
-
-encryption_key = b"QUMSqTJ5nape3p8joqkgHFCzyJdyQtqzHk6dCuGl9Nw="
-cipher_suite = Fernet(encryption_key)
-
-
-def save_api_key(api_key):
- encrypted_key = cipher_suite.encrypt(api_key.encode())
- config = configparser.ConfigParser()
- config["DEFAULT"] = {"api_key": encrypted_key.decode()}
-
- os.makedirs(config_dir, mode=0o700, exist_ok=True)
-
- def opener(path, flags):
- return os.open(path, flags, 0o600)
-
- with open(key_path, "w", opener=opener) as configfile:
- config.write(configfile)
-
-
-def load_api_key():
- try:
- config = configparser.ConfigParser()
- config.read(key_path)
- encrypted_key = config["DEFAULT"]["api_key"].encode()
- decrypted_key = cipher_suite.decrypt(encrypted_key)
- return decrypted_key.decode()
- except Exception:
- return None
-
-
-def delete_api_key():
- if os.path.isfile(key_path):
- os.remove(key_path)
-
-
-def get_history_path() -> Optional[os.PathLike]:
- return log_path
diff --git a/src/comai/context.py b/src/comai/context.py
index fddae26..f8c0b59 100644
--- a/src/comai/context.py
+++ b/src/comai/context.py
@@ -2,6 +2,8 @@
import sys
from dataclasses import dataclass
+session_id = os.getenv("TERM_SESSION_ID")
+
shell = os.getenv("SHELL")
if not shell:
shell = "bash"
@@ -17,6 +19,7 @@
class Context:
system: str
shell: str
+ session_id: str | None
def get_context() -> Context:
@@ -31,4 +34,4 @@ def get_context() -> Context:
shell = "bash"
shell = shell.split("/")[-1]
- return Context(system, shell)
+ return Context(system, shell, session_id)
diff --git a/src/comai/history.py b/src/comai/history.py
index 71977f0..ed3b7d6 100644
--- a/src/comai/history.py
+++ b/src/comai/history.py
@@ -1,46 +1,24 @@
-from __future__ import annotations
-
import os
-import pickle
-from copy import copy
-from typing import Optional, List
-
-
-class History:
- def __init__(self, messages: List, filepath: Optional[os.PathLike] = None) -> None:
- self.messages = messages
- self.filepath = filepath
-
- @classmethod
- def load_from_file(cls, filepath: os.PathLike) -> History:
- messages = []
- try:
- with open(filepath, "br") as history_file:
- messages = pickle.load(history_file)
- except Exception:
- pass
- return History(messages, filepath)
-
- @classmethod
- def create_local(cls) -> History:
- return History([], None)
-
- def append_user_message(self, request: str) -> None:
- user_message = {"role": "user", "content": request}
- self.messages += [user_message]
-
- def append_assistant_message(self, command: str) -> None:
- content = f"""COMMAND {command} END"""
- assistant_message = {"role": "user", "content": content}
- self.messages += [assistant_message]
-
- def get_messages(self) -> List:
- return self.messages
-
- def checkpoint(self) -> None:
- if self.filepath:
- with open(self.filepath, "bw") as history_file:
- pickle.dump(self.messages, history_file)
-
- def copy(self) -> History:
- return History(copy(self.messages), self.filepath)
+import tempfile
+from pathlib import Path
+
+from langchain_community.chat_message_histories import SQLChatMessageHistory
+from langchain_core.chat_history import InMemoryChatMessageHistory
+
+temp_dir = tempfile.gettempdir()
+session_id = os.getenv("TERM_SESSION_ID")
+history_path: os.PathLike | None = None
+if session_id:
+ try:
+ history_path = Path(os.path.join(temp_dir, session_id))
+ except Exception:
+ pass
+
+
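+# Persist the conversation in a per-terminal SQLite file named after
+# TERM_SESSION_ID; without a session id the history only lives in memory.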
+def load_history(session_id: str) -> SQLChatMessageHistory | InMemoryChatMessageHistory:
+ if history_path:
+ return SQLChatMessageHistory(
+ session_id=session_id, connection_string=f"sqlite:///{history_path}"
+ )
+ else:
+ return InMemoryChatMessageHistory()
diff --git a/src/comai/menu.py b/src/comai/menu.py
deleted file mode 100644
index 32d8b99..0000000
--- a/src/comai/menu.py
+++ /dev/null
@@ -1,28 +0,0 @@
-import click
-from enum import Enum
-from .animations import show_cursor
-from rich import print
-from rich.markup import escape
-
-
-class MenuOption(str, Enum):
- run = "r"
- cancel = "c"
-
-
-DEFAULT_OPTION = escape("[r]")
-MENU_PROMPT = f"[bright_black] ➜ [underline bold]r[/underline bold]un | [underline bold]c[/underline bold]ancel {DEFAULT_OPTION}:[/bright_black]"
-
-
-def get_option_from_menu() -> MenuOption:
- print(MENU_PROMPT, end="", flush=True)
- show_cursor()
- option = click.prompt(
- "",
- prompt_suffix="",
- type=MenuOption,
- default=MenuOption.run,
- show_default=False,
- show_choices=False,
- )
- return option
diff --git a/src/comai/ollama.py b/src/comai/ollama.py
new file mode 100644
index 0000000..8a1d99a
--- /dev/null
+++ b/src/comai/ollama.py
@@ -0,0 +1,8 @@
+import ollama
+
+
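+# List the models already pulled into the local ollama service, dropping the
+# ":latest" tag suffix so names match what users type (e.g. "llama3").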
+def get_ollama_model_names() -> list[str]:
+ models = ollama.list()["models"]
+ model_names = [model["name"] for model in models]
+ cleaned_model_names = [name.replace(":latest", "") for name in model_names]
+ return cleaned_model_names
diff --git a/src/comai/openai.py b/src/comai/openai.py
new file mode 100644
index 0000000..64e053d
--- /dev/null
+++ b/src/comai/openai.py
@@ -0,0 +1,12 @@
+from openai import OpenAI
+import os
+
+
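+# COMAI_OPENAI_API_KEY takes precedence over OPENAI_API_KEY; only "gpt-*" chat
+# models are offered in the configuration menu.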
+def get_openai_model_names() -> list[str]:
+ # FIXME: this piece of code is duplicated
+ default_key = os.environ.get("OPENAI_API_KEY")
+ comai_key = os.environ.get("COMAI_OPENAI_API_KEY")
+ api_key = comai_key if comai_key is not None else default_key
+
+ client = OpenAI(api_key=api_key)
+ return [model.id for model in client.models.list() if model.id.startswith("gpt-")]
diff --git a/src/comai/prompt.py b/src/comai/prompt.py
new file mode 100644
index 0000000..6807cf3
--- /dev/null
+++ b/src/comai/prompt.py
@@ -0,0 +1,46 @@
+from typing import Literal
+from simple_term_menu import TerminalMenu
+from rich import print
+import prompt_toolkit
+from prompt_toolkit.styles import Style
+
+
+def prompt_str(question: str, default: str) -> str:
+ style = Style.from_dict(
+ {
+ # User input (default text)
+ "": "ansicyan",
+ "mark": "ansicyan",
+ "question": "ansiwhite",
+ }
+ )
+ message = [
+ ("class:mark", "? "),
+ ("class:question", f"{question.strip()} "),
+ ]
+ return prompt_toolkit.prompt(message, default="%s" % default, style=style) # type: ignore
+
+
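+# Render an interactive selection menu with simple-term-menu and echo the chosen
+# option afterwards so it stays visible once the menu closes.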
+def prompt_options(question: str, options: list[str], default: str) -> str:
+ terminal_menu = TerminalMenu(
+ options,
+ title=f"? {question}",
+ menu_cursor="• ",
+ menu_cursor_style=("bg_black", "fg_green"),
+ menu_highlight_style=("bg_black", "fg_green"),
+ )
+ index = terminal_menu.show()
+ answer = options[index]
+ print(f"[cyan]?[/cyan] {question} [cyan]{answer}[/cyan]")
+ return answer
+
+
+def prompt_bool(question: str, default: bool) -> bool:
+    default_option = "Yes" if default else "No"
+ answer = prompt_options(question, ["Yes", "No"], default_option)
+ if answer == "Yes":
+ return True
+ elif answer == "No":
+ return False
+ else:
+ raise Exception(f"unexpected input {answer}")
diff --git a/src/comai/settings.py b/src/comai/settings.py
new file mode 100644
index 0000000..e5bc93b
--- /dev/null
+++ b/src/comai/settings.py
@@ -0,0 +1,40 @@
+import os
+import sys
+from typing import Literal
+import typer
+from pydantic import BaseModel
+
+APP_NAME = "comai"
+config_dir = typer.get_app_dir(APP_NAME, force_posix=True)
+settings_path = os.path.join(config_dir, "settings.json")
+
+
+class Settings(BaseModel):
+ provider: Literal["ollama", "openai"]
+ model: str = "llama3"
+
+
+DEFAULT_SETTINGS: Settings = Settings(provider="ollama")
+
+
+class InvalidSettingsFileException(Exception):
+ def __init__(self, settings_path: str):
+ self.settings_path = settings_path
+ super().__init__()
+
+
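+# A missing settings file falls back to the defaults; a malformed one raises
+# InvalidSettingsFileException so the CLI can point the user at its path.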
+def load_settings() -> Settings:
+ try:
+ with open(settings_path, "r") as file:
+ content = file.read()
+ return Settings.model_validate_json(content)
+ except FileNotFoundError:
+ return DEFAULT_SETTINGS
+ except Exception:
+ raise InvalidSettingsFileException(settings_path=settings_path)
+
+
+def write_settings(settings: Settings):
+ json = settings.model_dump_json(indent=2)
+ with open(settings_path, "w") as file:
+ file.write(json + "\n")
diff --git a/src/comai/translation.py b/src/comai/translation.py
deleted file mode 100644
index 1360b64..0000000
--- a/src/comai/translation.py
+++ /dev/null
@@ -1,106 +0,0 @@
-import os
-import openai
-from typing import Iterator, Optional
-from .context import Context
-from .history import History
-
-
-class CommandMissingException(Exception):
- pass
-
-
-def validate_api_key(openai_api_key) -> bool:
- try:
- openai.Model.list(api_key=openai_api_key)
- return True
- except Exception:
- return False
-
-
-def translate_with_history(
- instruction: str,
- history_path: Optional[os.PathLike],
- context: Context,
- openai_api_key: str,
-) -> Iterator[str]:
- history: History = History.create_local()
- if history_path:
- history = History.load_from_file(history_path)
- history.append_user_message(instruction)
-
- commands_chunks = []
- chunks = request_command(history, context, openai_api_key)
- try:
- for chunk in filter_assistant_message(chunks):
- yield chunk
- commands_chunks.append(chunk)
- except CommandMissingException:
- corrective_history = history.copy()
- corrective_history.append_user_message("stick to the format")
- chunks = request_command(corrective_history, context, openai_api_key)
- for chunk in filter_assistant_message(chunks):
- yield chunk
- commands_chunks.append(chunk)
-
- command = "".join(commands_chunks)
- history.append_assistant_message(command)
- history.checkpoint()
-
-
-def filter_assistant_message(chunks: Iterator[str]) -> Iterator[str]:
- # Filter all the chunks between COMMAND and END
- try:
- while "COMMAND" not in next(chunks):
- pass
-
- first_chunk = next(chunks)
- yield first_chunk[1:] # removes the space after "COMMAND"
-
- while "END" not in (chunk := next(chunks)):
- yield chunk
- except StopIteration:
- raise CommandMissingException
- return
-
-
-def request_command(
- current_history: History, context: Context, openai_api_key: str
-) -> Iterator[str]:
- openai.api_key = openai_api_key
-
- system_prompt = system_prompt_from_context(context)
- system_message = {"role": "system", "content": system_prompt}
- chat_messages = current_history.get_messages()
- messages = [system_message] + chat_messages
-
- response = openai.ChatCompletion.create(
- model="gpt-3.5-turbo",
- max_tokens=200,
- n=1,
- stop=None,
- temperature=0,
- stream=True,
- messages=messages,
- )
-
- for chunk in response:
- if "content" in chunk.choices[0].delta:
- yield chunk.choices[0].delta.content
-
-
-def system_prompt_from_context(context: Context) -> str:
- return f"""
- You are a CLI tool called comai with access to a {context.shell} shell on {context.system}. Your goal is to translate the instruction from the user into a command.
-
- ALWAYS use the following format, with no additional comments, explanation, or notes before or after AT ALL:
-
- COMMAND END
-
- Example:
-
- User:
- show files
-
- Your answer:
- COMMAND ls END
- """
diff --git a/tests/test_comai.py b/tests/test_comai.py
index 3ab3338..7809081 100644
--- a/tests/test_comai.py
+++ b/tests/test_comai.py
@@ -1,34 +1,28 @@
-import os
+import prompt_toolkit
+from typing import Any
from typer.testing import CliRunner
-from dotenv import load_dotenv
-from comai import cli, config, translation, context, __version__
-from comai.history import History
-load_dotenv()
-api_key = os.getenv("OPENAI_API_KEY")
+from langchain_core.messages import AIMessageChunk
+from langchain_community.chat_models import ChatOllama
-runner = CliRunner()
+from comai import cli, __version__
+runner = CliRunner()
-def test_invalid_api_key():
- config.delete_api_key()
- result = runner.invoke(cli.app, ["show", "files"], input="bad-api-key\n")
- assert result.exit_code != 0
- assert "API key not valid" in result.stdout
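+# Run the CLI end-to-end with the model stream and the interactive prompt mocked
+# out, so the test needs neither a running ollama service nor user input.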
+def test_normal_flow(monkeypatch):
+ def mock_stream(*args, **kwargs):
+ for token in ["COMMAND", " ls", " END"]:
+ yield AIMessageChunk(content=token)
+ def mock_prompt(message: str, default: str, style: Any = None):
+ return default
-def test_installation_flow():
- config.delete_api_key()
+ monkeypatch.setattr(ChatOllama, "stream", mock_stream)
+ monkeypatch.setattr(prompt_toolkit, "prompt", mock_prompt)
- result = runner.invoke(cli.app, ["show", "files"], input=f"{api_key}\n\n")
+ result = runner.invoke(cli.app, ["show", "files"])
assert result.exit_code == 0
- assert "Please enter your OpenAI API key:" in result.stdout
- assert "ls" in result.stdout
-
- result = runner.invoke(cli.app, ["show", "files"], input="\n")
- assert result.exit_code == 0
- assert "Please enter your OpenAI API key:" not in result.stdout
assert "ls" in result.stdout
@@ -41,13 +35,3 @@ def test_version():
def test_missing_instruction():
result = runner.invoke(cli.app, [])
assert result.exit_code != 0
-
-
-def test_translation():
- ctx = context.get_context()
- history = History.create_local()
- history.append_user_message("show files")
-
- answer = translation.request_command(history, ctx, api_key)
- command = translation.filter_assistant_message(answer)
- assert "".join(command) == "ls"