From 9e3e9cbcf92cc006d0ca94a2aa53b13ee276702d Mon Sep 17 00:00:00 2001
From: Pedro Rico Pinazo <ricopinazo@gmail.com>
Date: Mon, 6 May 2024 22:20:29 +0100
Subject: [PATCH] initial version for langchain migration

---
 .gitignore               |   3 +-
 pyproject.toml           |  20 +++----
 src/comai/animations.py  |  10 ++---
 src/comai/chain.py       | 146 ++++++++++++++++++++++++++++++++++++++++
 src/comai/cli.py         |  62 +++++++-----------
 src/comai/config.py      |  57 ----------------
 src/comai/context.py     |   5 +-
 src/comai/history.py     |  68 ++++++++-----------
 src/comai/settings.py    |  21 ++++++
 src/comai/translation.py | 106 ------------------------------
 10 files changed, 233 insertions(+), 265 deletions(-)
 create mode 100644 src/comai/chain.py
 delete mode 100644 src/comai/config.py
 create mode 100644 src/comai/settings.py
 delete mode 100644 src/comai/translation.py

diff --git a/.gitignore b/.gitignore
index f2058f0..41872d9 100644
--- a/.gitignore
+++ b/.gitignore
@@ -3,4 +3,5 @@ __pycache__
 .python-version
 dist
 .env
-.DS_Store
\ No newline at end of file
+.DS_Store
+.ipynb_checkpoints
diff --git a/pyproject.toml b/pyproject.toml
index ff8c1b3..be5d8b1 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -5,12 +5,10 @@ build-backend = "hatchling.build"
 [project]
 name = "comai"
 dynamic = ["version"]
-authors = [
-    { name="Pedro Rico", email="ricopinazo@gmail.com" },
-]
+authors = [{ name = "Pedro Rico", email = "ricopinazo@gmail.com" }]
 description = "AI powered console assistant"
 readme = "README.md"
-license = {file = "LICENSE"}
-requires-python = ">=3.7"
+license = { file = "LICENSE" }
+requires-python = ">=3.10"
 classifiers = [
     "Programming Language :: Python :: 3",
@@ -20,9 +18,9 @@ classifiers = [
     "Operating System :: Unix",
 ]
 dependencies = [
-    "typer[all]==0.9.0",
-    "openai==0.27.5",
-    "cryptography==40.0.2",
+    "typer[all]==0.9.0",
+    "langchain==0.1.17",
+    "langchain-openai==0.1.6",
 ]
 
 [project.urls]
@@ -34,11 +32,7 @@ issues = "https://github.com/ricopinazo/comai/issues"
 comai = "comai.cli:app"
 
 [project.optional-dependencies]
-test = [
-    "pytest",
-    "hatchling",
-    "python-dotenv"
-]
+test = ["pytest", "hatchling", "python-dotenv"]
 
 [tool.hatch.version]
 path = "src/comai/__init__.py"
diff --git a/src/comai/animations.py b/src/comai/animations.py
index 6c0f8c9..04482ff 100644
--- a/src/comai/animations.py
+++ b/src/comai/animations.py
@@ -42,12 +42,12 @@ def query_animation() -> Generator[None, None, None]:
         t.join()
 
 
-def print_answer(command_chunks: Iterator[str]):
+def start_printing_command():
     print(f"[{ANSWER_PROMPT_COLOR}]{ANSWER_PROMPT}", end="", flush=True)
-    first_chunk = next(command_chunks)
-    print(f"[{COMMAND_COLOR}]{first_chunk}", end="", flush=True)
-    for chunk in command_chunks:
-        print(f"[{COMMAND_COLOR}]{chunk}", end="", flush=True)
+
+
+def print_command_token(chunk: str):
+    print(f"[{COMMAND_COLOR}]{chunk}", end="", flush=True)
 
 
 def hide_cursor() -> None:
diff --git a/src/comai/chain.py b/src/comai/chain.py
new file mode 100644
index 0000000..0001c5b
--- /dev/null
+++ b/src/comai/chain.py
@@ -0,0 +1,146 @@
+from typing import Iterable, Iterator
+import re
+import logging
+import uuid
+import textwrap
+from dataclasses import dataclass
+
+from langchain_community.chat_models import ChatOllama
+from langchain_core.messages import AIMessageChunk, BaseMessage
+from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
+from langchain_core.runnables import RunnableGenerator
+from langchain_core.runnables.base import Runnable
+from langchain_core.runnables.history import (
+    MessagesOrDictWithMessages,
+    RunnableWithMessageHistory,
+)
+
+from comai.context import Context
+from comai.history import load_history
+from comai.settings import Settings
+
+logging.getLogger().setLevel(logging.CRITICAL)
+
+
+def create_prompt(context: Context):
+    system_message = f"""
+    You are a CLI tool called comai with access to a {context.shell} shell on {context.system}. Your goal is to translate the instruction from the user into a command.
+
+    ALWAYS use the following format, with no additional comments, explanation, or notes before or after AT ALL:
+
+    COMMAND <command> END
+
+    Example:
+
+    User:
+    show files
+
+    Your answer:
+    COMMAND ls END
+    """
+
+    return ChatPromptTemplate.from_messages(
+        [
+            ("system", textwrap.dedent(system_message).strip()),
+            MessagesPlaceholder(variable_name="history"),
+            ("human", "{question}"),
+        ]
+    )
+
+
+def parse_command(tokens: Iterable[AIMessageChunk]) -> Iterator[str]:
+    """Incrementally extract the text between COMMAND and END from a token stream."""
+    input = ""
+    output = ""
+    for token in tokens:
+        if isinstance(token.content, str):
+            input += token.content
+        match = re.search(r"COMMAND(.*?)END", input)
+        if match:
+            updated_output = match.group(1).strip()
+        else:
+            match = re.search(r"COMMAND(.*)$", input)
+            if match is None:
+                continue
+            tail = match.group(1)
+            # hold back a trailing partial "END" marker so it never leaks downstream
+            for partial in ("EN", "E"):
+                if tail.endswith(partial):
+                    tail = tail[: -len(partial)]
+                    break
+            updated_output = tail.strip()
+        if len(updated_output) > len(output):
+            yield updated_output[len(output):]
+            output = updated_output
+
+
+parse_command_generator = RunnableGenerator(parse_command)
+
+
+def attach_history(
+    runnable: Runnable[
+        MessagesOrDictWithMessages,
+        str | BaseMessage | MessagesOrDictWithMessages,
+    ]
+):
+    return RunnableWithMessageHistory(
+        runnable,
+        lambda session_id: load_history(session_id),
+        input_messages_key="question",
+        history_messages_key="history",
+    )
+
+
+def create_chain_stream(settings: Settings, context: Context):
+    prompt = create_prompt(context)
+    # TODO: honour settings.provider; "openai" is accepted by Settings but not wired up yet
+    model = ChatOllama(model=settings.model)
+    base_chain = prompt | model
+
+    if context.session_id is not None:
+        session_id = context.session_id
+    else:
+        session_id = str(uuid.uuid4())  # FIXME: should just not use history at all
+
+    chain_with_history = attach_history(base_chain)
+    runnable = chain_with_history | parse_command_generator
+
+    def stream(input: dict[str, str]):
+        return runnable.stream(
+            input=input, config={"configurable": {"session_id": session_id}}
+        )
+
+    return stream
+
+
+# TODO: move these stream event types to a different file
+@dataclass
+class StreamStart:
+    pass
+
+
+@dataclass
+class Token:
+    content: str
+
+
+@dataclass
+class FinalCommand:
+    command: str
+
+
+def query_command(
+    query: str, settings: Settings, context: Context
+) -> Iterator[StreamStart | Token | FinalCommand]:
+    stream = create_chain_stream(settings, context)
+    output = stream({"question": query})
+
+    started = False
+    buffer = ""
+    for token in output:
+        if not started:
+            started = True
+            yield StreamStart()
+        yield Token(token)
+        buffer += token
+    yield FinalCommand(buffer)
diff --git a/src/comai/cli.py b/src/comai/cli.py
index 58fe3ef..79e7681 100755
--- a/src/comai/cli.py
+++ b/src/comai/cli.py
@@ -4,13 +4,17 @@
 from typing import List, Optional, Iterator
 from typing_extensions import Annotated
 
-from . import config, context, translation, __version__
-from .menu import get_option_from_menu, MenuOption
-from .animations import (
+from comai import __version__
+from comai.chain import StreamStart, Token, FinalCommand, query_command
+from comai.settings import load_settings
+from comai.context import get_context
+from comai.menu import get_option_from_menu, MenuOption
+from comai.animations import (
+    print_command_token,
     query_animation,
-    print_answer,
     show_cursor,
     hide_cursor,
+    start_printing_command,
 )
 
 app = typer.Typer()
@@ -22,49 +26,33 @@ def version_callback(value: bool):
         raise typer.Exit()
 
 
-def save_command(command_chunks, command_backup: list) -> Iterator[str]:
-    for chunk in command_chunks:
-        command_backup.append(chunk)
-        yield chunk
-
-
-def wait_for_first_chunk(iterator: Iterator[str]):
-    iter1, iter2 = itertools.tee(iterator)
-    _ = next(iter1)
-    return iter2
-
-
 def main_normal_flow(instructions: List[str]):
+    final_command: str | None = None
     input_text = " ".join(instructions)
 
-    api_key = config.load_api_key()
-    if not api_key:
-        api_key = typer.prompt("Please enter your OpenAI API key")
-        assert len(api_key) > 0
-        if not translation.validate_api_key(api_key):
-            print("API key not valid")
-            exit(1)
-        config.save_api_key(api_key)
-
     hide_cursor()
 
-    command_chunks: Iterator[str] = iter(())
-    command_backup: List[str] = []
+    settings = load_settings()
+    context = get_context()
+    output = query_command(input_text, settings, context)
     with query_animation():
-        ctx = context.get_context()
-        history_path = config.get_history_path()
-        command_chunks = translation.translate_with_history(
-            input_text, history_path, ctx, api_key
-        )
-        command_chunks = save_command(command_chunks, command_backup)
-        command_chunks = wait_for_first_chunk(command_chunks)
+        stream_start = next(output)
+        assert isinstance(stream_start, StreamStart)
 
-    print_answer(command_chunks)
-    command: str = "".join(command_backup)
+    start_printing_command()
+    for chunk in output:
+        match chunk:
+            case Token(token):
+                print_command_token(token)
+            case FinalCommand(command):
+                final_command = command
+
+    if final_command is None:
+        raise Exception("failed to fetch command")
 
     match get_option_from_menu():
         case MenuOption.run:
-            os.system(command)
+            os.system(final_command)
         case MenuOption.cancel:
             pass
 
diff --git a/src/comai/config.py b/src/comai/config.py
deleted file mode 100644
index 3d15686..0000000
--- a/src/comai/config.py
+++ /dev/null
@@ -1,57 +0,0 @@
-import os
-import configparser
-from pathlib import Path
-import typer
-import tempfile
-from typing import Optional
-from cryptography.fernet import Fernet
-
-CONTEXT_SIZE = 20
-APP_NAME = "comai"
-config_dir = typer.get_app_dir(APP_NAME, force_posix=True)
-key_path = os.path.join(config_dir, "config.ini")
-temp_dir = tempfile.gettempdir()
-session_id = os.getenv("TERM_SESSION_ID")
-log_path: Optional[os.PathLike] = None
-if session_id:
-    try:
-        log_path = Path(os.path.join(temp_dir, session_id))
-    except Exception:
-        pass
-
-encryption_key = b"QUMSqTJ5nape3p8joqkgHFCzyJdyQtqzHk6dCuGl9Nw="
-cipher_suite = Fernet(encryption_key)
-
-
-def save_api_key(api_key):
-    encrypted_key = cipher_suite.encrypt(api_key.encode())
-    config = configparser.ConfigParser()
-    config["DEFAULT"] = {"api_key": encrypted_key.decode()}
-
-    os.makedirs(config_dir, mode=0o700, exist_ok=True)
-
-    def opener(path, flags):
-        return os.open(path, flags, 0o600)
-
-    with open(key_path, "w", opener=opener) as configfile:
-        config.write(configfile)
-
-
-def load_api_key():
-    try:
-        config = configparser.ConfigParser()
-        config.read(key_path)
-        encrypted_key = config["DEFAULT"]["api_key"].encode()
-        decrypted_key = cipher_suite.decrypt(encrypted_key)
-        return decrypted_key.decode()
-    except Exception:
-        return None
-
-
-def delete_api_key():
-    if os.path.isfile(key_path):
-        os.remove(key_path)
-
-
-def get_history_path() -> Optional[os.PathLike]:
-    return log_path
diff --git a/src/comai/context.py b/src/comai/context.py
index fddae26..f8c0b59 100644
--- a/src/comai/context.py
+++ b/src/comai/context.py
@@ -2,6 +2,8 @@
 import sys
 
 from dataclasses import dataclass
 
+session_id = os.getenv("TERM_SESSION_ID")
+
 shell = os.getenv("SHELL")
 if not shell:
@@ -17,6 +19,7 @@ class Context:
 
     system: str
     shell: str
+    session_id: str | None
 
 
 def get_context() -> Context:
@@ -31,4 +34,4 @@ def get_context() -> Context:
         shell = "bash"
     shell = shell.split("/")[-1]
 
-    return Context(system, shell)
+    return Context(system, shell, session_id)
diff --git a/src/comai/history.py b/src/comai/history.py
index 71977f0..b1b15a6 100644
--- a/src/comai/history.py
+++ b/src/comai/history.py
@@ -1,46 +1,24 @@
-from __future__ import annotations
-
 import os
-import pickle
-from copy import copy
-from typing import Optional, List
-
-
-class History:
-    def __init__(self, messages: List, filepath: Optional[os.PathLike] = None) -> None:
-        self.messages = messages
-        self.filepath = filepath
-
-    @classmethod
-    def load_from_file(cls, filepath: os.PathLike) -> History:
-        messages = []
-        try:
-            with open(filepath, "br") as history_file:
-                messages = pickle.load(history_file)
-        except Exception:
-            pass
-        return History(messages, filepath)
-
-    @classmethod
-    def create_local(cls) -> History:
-        return History([], None)
-
-    def append_user_message(self, request: str) -> None:
-        user_message = {"role": "user", "content": request}
-        self.messages += [user_message]
-
-    def append_assistant_message(self, command: str) -> None:
-        content = f"""COMMAND {command} END"""
-        assistant_message = {"role": "user", "content": content}
-        self.messages += [assistant_message]
-
-    def get_messages(self) -> List:
-        return self.messages
-
-    def checkpoint(self) -> None:
-        if self.filepath:
-            with open(self.filepath, "bw") as history_file:
-                pickle.dump(self.messages, history_file)
-
-    def copy(self) -> History:
-        return History(copy(self.messages), self.filepath)
+import tempfile
+from pathlib import Path
+
+from langchain_community.chat_message_histories import SQLChatMessageHistory
+from langchain_core.chat_history import InMemoryChatMessageHistory
+
+temp_dir = tempfile.gettempdir()
+session_id = os.getenv("TERM_SESSION_ID")
+history_path: os.PathLike | None = None
+if session_id:
+    try:
+        history_path = Path(os.path.join(temp_dir, session_id))
+    except Exception:
+        pass
+
+
+def load_history(session_id: str) -> SQLChatMessageHistory | InMemoryChatMessageHistory:
+    if history_path:
+        return SQLChatMessageHistory(
+            session_id=session_id, connection_string=f"sqlite:///{history_path}"
+        )
+    else:
+        return InMemoryChatMessageHistory()
diff --git a/src/comai/settings.py b/src/comai/settings.py
new file mode 100644
index 0000000..e259272
--- /dev/null
+++ b/src/comai/settings.py
@@ -0,0 +1,21 @@
+import os
+from typing import Literal
+import typer
+from pydantic import BaseModel
+
+APP_NAME = "comai"
+config_dir = typer.get_app_dir(APP_NAME, force_posix=True)
+settings_path = os.path.join(config_dir, "settings.json")
+
+class Settings(BaseModel):
+    provider: Literal["ollama", "openai"]
+    model: str = "llama3"  # TODO: type the model per provider; available models can be queried at runtime
+
+DEFAULT_SETTINGS: Settings = Settings(provider="ollama", model="llama3")
+
+def load_settings() -> Settings:
+    try:
+        return Settings.parse_file(settings_path)
+    except Exception:
+        # TODO: if a settings file exists but is invalid, complain instead of silently returning the default
+        return DEFAULT_SETTINGS
diff --git a/src/comai/translation.py b/src/comai/translation.py
deleted file mode 100644
index 1360b64..0000000
--- a/src/comai/translation.py
+++ /dev/null
@@ -1,106 +0,0 @@
-import os
-import openai
-from typing import Iterator, Optional
-from .context import Context
-from .history import History
-
-
-class CommandMissingException(Exception):
-    pass
-
-
-def validate_api_key(openai_api_key) -> bool:
-    try:
-        openai.Model.list(api_key=openai_api_key)
-        return True
-    except Exception:
-        return False
-
-
-def translate_with_history(
-    instruction: str,
-    history_path: Optional[os.PathLike],
-    context: Context,
-    openai_api_key: str,
-) -> Iterator[str]:
-    history: History = History.create_local()
-    if history_path:
-        history = History.load_from_file(history_path)
-    history.append_user_message(instruction)
-
-    commands_chunks = []
-    chunks = request_command(history, context, openai_api_key)
-    try:
-        for chunk in filter_assistant_message(chunks):
-            yield chunk
-            commands_chunks.append(chunk)
-    except CommandMissingException:
-        corrective_history = history.copy()
-        corrective_history.append_user_message("stick to the format")
-        chunks = request_command(corrective_history, context, openai_api_key)
-        for chunk in filter_assistant_message(chunks):
-            yield chunk
-            commands_chunks.append(chunk)
-
-    command = "".join(commands_chunks)
-    history.append_assistant_message(command)
-    history.checkpoint()
-
-
-def filter_assistant_message(chunks: Iterator[str]) -> Iterator[str]:
-    # Filter all the chunks between COMMAND and END
-    try:
-        while "COMMAND" not in next(chunks):
-            pass
-
-        first_chunk = next(chunks)
-        yield first_chunk[1:]  # removes the space after "COMMAND"
-
-        while "END" not in (chunk := next(chunks)):
-            yield chunk
-    except StopIteration:
-        raise CommandMissingException
-    return
-
-
-def request_command(
-    current_history: History, context: Context, openai_api_key: str
-) -> Iterator[str]:
-    openai.api_key = openai_api_key
-
-    system_prompt = system_prompt_from_context(context)
-    system_message = {"role": "system", "content": system_prompt}
-    chat_messages = current_history.get_messages()
-    messages = [system_message] + chat_messages
-
-    response = openai.ChatCompletion.create(
-        model="gpt-3.5-turbo",
-        max_tokens=200,
-        n=1,
-        stop=None,
-        temperature=0,
-        stream=True,
-        messages=messages,
-    )
-
-    for chunk in response:
-        if "content" in chunk.choices[0].delta:
-            yield chunk.choices[0].delta.content
-
-
-def system_prompt_from_context(context: Context) -> str:
-    return f"""
-    You are a CLI tool called comai with access to a {context.shell} shell on {context.system}. Your goal is to translate the instruction from the user into a command.
-
-    ALWAYS use the following format, with no additional comments, explanation, or notes before or after AT ALL:
-
-    COMMAND <command> END
-
-    Example:
-
-    User:
-    show files
-
-    Your answer:
-    COMMAND ls END
-    """
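
---

Reviewer sketch (not part of the diff): the streaming contract that cli.py
relies on is easiest to see in isolation. Below is a minimal standalone
version of chain.py's parse_command, fed hand-made chunks instead of live
model output. SimpleNamespace stands in for AIMessageChunk, and the chunk
boundaries are invented for illustration, including a worst-case split of
the "END" marker:

    import re
    from types import SimpleNamespace
    from typing import Iterator

    def parse_command(tokens) -> Iterator[str]:
        # Same logic as src/comai/chain.py: accumulate chunks, extract the
        # text between COMMAND and END, and yield only the unseen suffix.
        input = ""
        output = ""
        for token in tokens:
            if isinstance(token.content, str):
                input += token.content
            match = re.search(r"COMMAND(.*?)END", input)
            if match:
                updated = match.group(1).strip()
            else:
                match = re.search(r"COMMAND(.*)$", input)
                if match is None:
                    continue
                tail = match.group(1)
                # hold back a trailing partial "END" marker
                for partial in ("EN", "E"):
                    if tail.endswith(partial):
                        tail = tail[: -len(partial)]
                        break
                updated = tail.strip()
            if len(updated) > len(output):
                yield updated[len(output):]
                output = updated

    chunks = [SimpleNamespace(content=c) for c in ["COM", "MAND l", "s -la EN", "D"]]
    assert "".join(parse_command(chunks)) == "ls -la"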
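And a sketch of how a caller consumes query_command end to end, mirroring
main_normal_flow in cli.py. It assumes this branch is installed and that a
local Ollama server has the default llama3 model pulled:

    from comai.chain import FinalCommand, StreamStart, Token, query_command
    from comai.context import get_context
    from comai.settings import load_settings

    output = query_command("show hidden files", load_settings(), get_context())

    # The first event marks that the model has started answering; cli.py uses
    # it to stop the spinner before printing the command.
    assert isinstance(next(output), StreamStart)
    for event in output:
        match event:
            case Token(token):
                print(token, end="", flush=True)
            case FinalCommand(command):
                print(f"\n-> would run: {command}")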