initial version for langchain migration

ricopinazo committed May 6, 2024
1 parent e4e5566 commit 9e3e9cb
Showing 10 changed files with 261 additions and 264 deletions.
3 changes: 2 additions & 1 deletion .gitignore
@@ -3,4 +3,5 @@ __pycache__
 .python-version
 dist
 .env
-.DS_Store
+.DS_Store
+.ipynb_checkpoints
18 changes: 6 additions & 12 deletions pyproject.toml
@@ -5,12 +5,10 @@ build-backend = "hatchling.build"
 [project]
 name = "comai"
 dynamic = ["version"]
-authors = [
-    { name="Pedro Rico", email="[email protected]" },
-]
+authors = [{ name = "Pedro Rico", email = "[email protected]" }]
 description = "AI powered console assistant"
 readme = "README.md"
-license = {file = "LICENSE"}
+license = { file = "LICENSE" }
 requires-python = ">=3.7"
 classifiers = [
     "Programming Language :: Python :: 3",
@@ -20,9 +18,9 @@ classifiers = [
     "Operating System :: Unix",
 ]
 dependencies = [
-    "typer[all]==0.9.0",
-    "openai==0.27.5",
-    "cryptography==40.0.2",
+    "typer[all]==0.9.0",
+    "langchain==0.1.17",
+    "langchain-openai==0.1.6",
 ]
 
 [project.urls]
@@ -34,11 +32,7 @@ issues = "https://github.com/ricopinazo/comai/issues"
 comai = "comai.cli:app"
 
 [project.optional-dependencies]
-test = [
-    "pytest",
-    "hatchling",
-    "python-dotenv"
-]
+test = ["pytest", "hatchling", "python-dotenv"]
 
 [tool.hatch.version]
 path = "src/comai/__init__.py"
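The dependency swap above is the core of the migration: the raw openai and cryptography pins are gone, replaced by langchain and langchain-openai. A quick way to sanity-check the new stack (not part of the commit; assumes a local Ollama server with the llama3 tag pulled, matching the ChatOllama usage in chain.py below):

    # Minimal sketch: stream tokens through the pinned langchain stack.
    from langchain_community.chat_models import ChatOllama

    model = ChatOllama(model="llama3")  # assumed model tag
    for chunk in model.stream("Reply with the single word: ok"):
        print(chunk.content, end="", flush=True)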
16 changes: 11 additions & 5 deletions src/comai/animations.py
@@ -42,12 +42,18 @@ def query_animation() -> Generator[None, None, None]:
     t.join()
 
 
-def print_answer(command_chunks: Iterator[str]):
+# def print_answer(command_chunks: Iterator[str]):
+#     print(f"[{ANSWER_PROMPT_COLOR}]{ANSWER_PROMPT}", end="", flush=True)
+#     first_chunk = next(command_chunks)
+#     print(f"[{COMMAND_COLOR}]{first_chunk}", end="", flush=True)
+#     for chunk in command_chunks:
+#         print(f"[{COMMAND_COLOR}]{chunk}", end="", flush=True)
+
+def start_printing_command():
     print(f"[{ANSWER_PROMPT_COLOR}]{ANSWER_PROMPT}", end="", flush=True)
-    first_chunk = next(command_chunks)
-    print(f"[{COMMAND_COLOR}]{first_chunk}", end="", flush=True)
-    for chunk in command_chunks:
-        print(f"[{COMMAND_COLOR}]{chunk}", end="", flush=True)
+
+def print_command_token(chunk: str):
+    print(f"[{COMMAND_COLOR}]{chunk}", end="", flush=True)
 
 
 def hide_cursor() -> None:
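The old print_answer owned the whole output loop; the new API splits it into start_printing_command (prints the answer prompt once) and print_command_token (prints each streamed chunk), so the caller controls iteration. A hypothetical caller, with hard-coded chunks standing in for the LLM stream:

    # Sketch of the new call pattern from animations.py.
    from comai.animations import start_printing_command, print_command_token

    start_printing_command()
    for chunk in ["ls", " -la"]:
        print_command_token(chunk)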
128 changes: 128 additions & 0 deletions src/comai/chain.py
@@ -0,0 +1,128 @@
from langchain_core.runnables.utils import ConfigurableFieldSpec
from langchain_core.tools import tool
from langchain.globals import set_debug, set_verbose
from langchain_community.chat_models import ChatOllama
from langchain_core.messages import AIMessage, AIMessageChunk
from langchain_core.runnables import RunnableGenerator
from typing import Iterable, Iterator, List
import re
import logging
import uuid
import itertools
import textwrap
from dataclasses import dataclass

from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.runnables.history import MessagesOrDictWithMessages, RunnableWithMessageHistory
from langchain_community.chat_message_histories import SQLChatMessageHistory
from langchain_core.runnables.base import Runnable
from langchain_core.messages import BaseMessage

from comai.context import Context
from comai.history import load_history
from comai.settings import Settings

logging.getLogger().setLevel(logging.CRITICAL)


def create_prompt(context: Context):
    system_message = f"""
    You are a CLI tool called comai with access to a {context.shell} shell on {context.system}. Your goal is to translate the instruction from the user into a command.
    ALWAYS use the following format, with no additional comments, explanation, or notes before or after AT ALL:
    COMMAND <the command> END
    Example:
    User:
    show files
    Your answer:
    COMMAND ls END
    """

    return ChatPromptTemplate.from_messages(
        [
            ("system", textwrap.dedent(system_message).strip()),
            MessagesPlaceholder(variable_name="history"),
            ("human", "{question}"),
        ]
    )

def parse_command(tokens: Iterable[AIMessageChunk]) -> Iterator[str]:
    input = ""
    output: str = ""
    for token in tokens:
        if type(token.content) == str:
            input += token.content
            pattern = r"COMMAND(.*?)(END|$)"
            match = re.search(pattern, input)
            if match:
                updated_output = match.group(1).strip()
                if len(updated_output) > len(output):
                    yield updated_output[len(output):]
                    output = updated_output


parse_command_generator = RunnableGenerator(parse_command)


def attatch_history(runnable: Runnable[
    MessagesOrDictWithMessages,
    str | BaseMessage | MessagesOrDictWithMessages,
]):
    return RunnableWithMessageHistory(
        runnable,
        lambda session_id: load_history(session_id),
        input_messages_key="question",
        history_messages_key="history",
    )


def create_chain_stream(settings: Settings, context: Context):
    prompt = create_prompt(context)
    model = ChatOllama(model=settings.model)
    base_chain = prompt | model

    if context.session_id is not None:
        session_id = context.session_id
    else:
        session_id = str(uuid.uuid4())  # FIXME: should just not use history at all

    chain_with_history = attatch_history(base_chain)
    runnable = chain_with_history | parse_command_generator

    def stream(input: dict[str, str]):
        return runnable.stream(
            input=input,
            config={"configurable": {"session_id": session_id}},
        )

    return stream

# TODO: move this to different file
@dataclass
class StreamStart:
    pass

@dataclass
class Token:
    content: str

@dataclass
class FinalCommand:
    command: str

def query_command(query: str, settings: Settings, context: Context) -> Iterator[StreamStart | Token | FinalCommand]:
    stream = create_chain_stream(settings, context)
    output = stream({"question": query})

    started = False
    buffer = ""
    for token in output:
        if not started:
            started = True
            yield StreamStart()
        yield Token(token)
        buffer += token
    yield FinalCommand(buffer)
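parse_command is the piece that makes streaming work: it strips the COMMAND ... END framing and yields only the newly arrived command text as the accumulated input grows, so printing can start before the model emits END. An illustrative check (not part of the commit):

    # Hand-made chunks stand in for the model stream; only the framed
    # command text comes out, emitted incrementally.
    from langchain_core.messages import AIMessageChunk
    from comai.chain import parse_command

    chunks = [AIMessageChunk(content=c) for c in ["COMMAND", " ls -la", " END"]]
    print(list(parse_command(chunks)))  # ["ls -la"]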
62 changes: 25 additions & 37 deletions src/comai/cli.py
@@ -4,13 +4,17 @@
 from typing import List, Optional, Iterator
 from typing_extensions import Annotated
 
-from . import config, context, translation, __version__
-from .menu import get_option_from_menu, MenuOption
-from .animations import (
+from comai import config, __version__
+from comai.chain import StreamStart, Token, FinalCommand, query_command
+from comai.settings import load_settings
+from comai.context import get_context
+from comai.menu import get_option_from_menu, MenuOption
+from comai.animations import (
+    print_command_token,
     query_animation,
-    print_answer,
     show_cursor,
     hide_cursor,
+    start_printing_command,
 )
 
 app = typer.Typer()
@@ -22,49 +26,33 @@ def version_callback(value: bool):
         raise typer.Exit()
 
 
-def save_command(command_chunks, command_backup: list) -> Iterator[str]:
-    for chunk in command_chunks:
-        command_backup.append(chunk)
-        yield chunk
-
-
-def wait_for_first_chunk(iterator: Iterator[str]):
-    iter1, iter2 = itertools.tee(iterator)
-    _ = next(iter1)
-    return iter2
-
-
 def main_normal_flow(instructions: List[str]):
+    final_command: str | None = None
     input_text = " ".join(instructions)
 
-    api_key = config.load_api_key()
-    if not api_key:
-        api_key = typer.prompt("Please enter your OpenAI API key")
-        assert len(api_key) > 0
-        if not translation.validate_api_key(api_key):
-            print("API key not valid")
-            exit(1)
-        config.save_api_key(api_key)
-
     hide_cursor()
 
-    command_chunks: Iterator[str] = iter(())
-    command_backup: List[str] = []
+    settings = load_settings()
+    context = get_context()
+    output = query_command(input_text, settings, context)
     with query_animation():
-        ctx = context.get_context()
-        history_path = config.get_history_path()
-        command_chunks = translation.translate_with_history(
-            input_text, history_path, ctx, api_key
-        )
-        command_chunks = save_command(command_chunks, command_backup)
-        command_chunks = wait_for_first_chunk(command_chunks)
+        stream_start = next(output)
+        assert(type(stream_start) == StreamStart)
 
-    print_answer(command_chunks)
-    command: str = "".join(command_backup)
+    start_printing_command()
+    for chunk in output:
+        match chunk:
+            case Token(token):
+                print_command_token(token)
+            case FinalCommand(command):
+                final_command = command
+
+    if final_command is None:
+        raise Exception("failed to fetch command")
 
     match get_option_from_menu():
         case MenuOption.run:
-            os.system(command)
+            os.system(final_command)
         case MenuOption.cancel:
             pass
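The rewritten main_normal_flow leans on Python 3.10 structural pattern matching: since Token and FinalCommand are dataclasses, case Token(token) binds the content field positionally through the generated __match_args__. Note that this (and the str | None annotations) requires Python 3.10, while pyproject.toml still declares requires-python = ">=3.7". A stripped-down illustration with hypothetical values:

    # How the match statement dispatches on the chain's event types.
    from comai.chain import Token, FinalCommand

    for event in [Token("ls"), Token(" -la"), FinalCommand("ls -la")]:
        match event:
            case Token(content):
                print("token:", repr(content))
            case FinalCommand(command):
                print("final:", command)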
57 changes: 0 additions & 57 deletions src/comai/config.py

This file was deleted.

5 changes: 4 additions & 1 deletion src/comai/context.py
@@ -2,6 +2,8 @@
 import sys
 from dataclasses import dataclass
 
+session_id = os.getenv("TERM_SESSION_ID")
+
 shell = os.getenv("SHELL")
 if not shell:
     shell = "bash"
@@ -17,6 +19,7 @@
 class Context:
     system: str
     shell: str
+    session_id: str | None
 
 
 def get_context() -> Context:
@@ -31,4 +34,4 @@ def get_context() -> Context:
         shell = "bash"
     shell = shell.split("/")[-1]
 
-    return Context(system, shell)
+    return Context(system, shell, session_id)
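context.py now captures TERM_SESSION_ID (set by macOS terminal emulators, among others) so each terminal session keeps its own chat history; when it is absent, create_chain_stream falls back to a random UUID. comai/history.py is among the files not shown in this view, but given the SQLChatMessageHistory import in chain.py, load_history plausibly looks something like this (a sketch, with the database location as an assumption):

    # Hypothetical load_history: one SQL-backed message history per session.
    from langchain_community.chat_message_histories import SQLChatMessageHistory

    def load_history(session_id: str) -> SQLChatMessageHistory:
        return SQLChatMessageHistory(
            session_id=session_id,
            connection_string="sqlite:///comai_history.db",  # assumed path
        )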
(3 more changed files in this commit are not shown.)