Commit 9e3e9cb: initial version for langchain migration
1 parent: e4e5566
Showing 10 changed files with 261 additions and 264 deletions.

.gitignore
@@ -3,4 +3,5 @@ __pycache__
 .python-version
 dist
 .env
-.DS_Store
+.DS_Store
+.ipynb_checkpoints

pyproject.toml
@@ -5,12 +5,10 @@ build-backend = "hatchling.build"
 [project]
 name = "comai"
 dynamic = ["version"]
-authors = [
-    { name="Pedro Rico", email="[email protected]" },
-]
+authors = [{ name = "Pedro Rico", email = "[email protected]" }]
 description = "AI powered console assistant"
 readme = "README.md"
-license = {file = "LICENSE"}
+license = { file = "LICENSE" }
 requires-python = ">=3.7"
 classifiers = [
     "Programming Language :: Python :: 3",
@@ -20,9 +18,9 @@ classifiers = [
     "Operating System :: Unix",
 ]
 dependencies = [
-    "typer[all]==0.9.0",
-    "openai==0.27.5",
-    "cryptography==40.0.2",
+    "typer[all]==0.9.0",
+    "langchain==0.1.17",
+    "langchain-openai==0.1.6",
 ]

 [project.urls]
@@ -34,11 +32,7 @@ issues = "https://github.com/ricopinazo/comai/issues"
 comai = "comai.cli:app"

 [project.optional-dependencies]
-test = [
-    "pytest",
-    "hatchling",
-    "python-dotenv"
-]
+test = ["pytest", "hatchling", "python-dotenv"]

 [tool.hatch.version]
 path = "src/comai/__init__.py"
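
The dependency changes drop the direct openai and cryptography pins in favour of langchain and langchain-openai. As a rough orientation only, here is a minimal sketch of the prompt-then-model composition style those packages enable; the prompt text and the use of ChatOpenAI are illustrative assumptions, not code from this commit (the chain this commit actually builds uses ChatOllama, in the new module shown further below):

# Minimal sketch only, not part of this commit.
from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI  # provided by the new langchain-openai dependency

prompt = ChatPromptTemplate.from_messages(
    [
        ("system", "Translate the user's instruction into a single shell command."),
        ("human", "{question}"),
    ]
)
model = ChatOpenAI()  # assumes OPENAI_API_KEY is set in the environment
chain = prompt | model  # LCEL composition: the rendered prompt feeds the chat model

# invoke returns an AIMessage; .content holds the model's reply
print(chain.invoke({"question": "show files"}).content)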

New file added in this commit (@@ -0,0 +1,128 @@), a Python module:

from langchain_core.runnables.utils import ConfigurableFieldSpec
from langchain_core.tools import tool
from langchain.globals import set_debug, set_verbose
from langchain_community.chat_models import ChatOllama
from langchain_core.messages import AIMessage, AIMessageChunk
from langchain_core.runnables import RunnableGenerator
from typing import Iterable, Iterator, List
import re
import logging
import uuid
import itertools
import textwrap
from dataclasses import dataclass

from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.runnables.history import MessagesOrDictWithMessages, RunnableWithMessageHistory
from langchain_community.chat_message_histories import SQLChatMessageHistory
from langchain_core.runnables.base import Runnable
from langchain_core.messages import BaseMessage

from comai.context import Context
from comai.history import load_history
from comai.settings import Settings

logging.getLogger().setLevel(logging.CRITICAL)


def create_prompt(context: Context):
    # System prompt: the model must wrap its answer in COMMAND ... END markers
    # so the streaming parser below can extract the bare command.
    system_message = f"""
    You are a CLI tool called comai with access to a {context.shell} shell on {context.system}. Your goal is to translate the instruction from the user into a command.
    ALWAYS use the following format, with no additional comments, explanation, or notes before or after AT ALL:
    COMMAND <the command> END
    Example:
    User:
    show files
    Your answer:
    COMMAND ls END
    """

    return ChatPromptTemplate.from_messages(
        [
            ("system", textwrap.dedent(system_message).strip()),
            MessagesPlaceholder(variable_name="history"),
            ("human", "{question}"),
        ]
    )


def parse_command(tokens: Iterable[AIMessageChunk]) -> Iterator[str]:
    # Incrementally scans the streamed tokens and yields only the text that
    # appears between the COMMAND and END markers, as soon as it arrives.
    input = ""
    output: str = ""
    for token in tokens:
        if type(token.content) == str:
            input += token.content
        pattern = r"COMMAND(.*?)(END|$)"
        match = re.search(pattern, input)
        if match:
            updated_output = match.group(1).strip()
            if len(updated_output) > len(output):
                yield updated_output[len(output):]
                output = updated_output


parse_command_generator = RunnableGenerator(parse_command)


def attatch_history(runnable: Runnable[
    MessagesOrDictWithMessages,
    str | BaseMessage | MessagesOrDictWithMessages,
]):
    # Wraps the chain so past exchanges are injected into the "history"
    # placeholder and new ones are persisted per session_id.
    return RunnableWithMessageHistory(
        runnable,
        lambda session_id: load_history(session_id),
        input_messages_key="question",
        history_messages_key="history",
    )


def create_chain_stream(settings: Settings, context: Context):
    prompt = create_prompt(context)
    model = ChatOllama(model=settings.model)
    base_chain = prompt | model

    if context.session_id is not None:
        session_id = context.session_id
    else:
        session_id = str(uuid.uuid4())  # FIXME: should just not use history at all

    chain_with_history = attatch_history(base_chain)
    runnable = chain_with_history | parse_command_generator

    def stream(input: dict[str, str]):
        return runnable.stream(
            input=input,
            config={"configurable": {"session_id": session_id}},
        )

    return stream


# TODO: move this to different file
@dataclass
class StreamStart:
    pass


@dataclass
class Token:
    content: str


@dataclass
class FinalCommand:
    command: str


def query_command(query: str, settings: Settings, context: Context) -> Iterator[StreamStart | Token | FinalCommand]:
    stream = create_chain_stream(settings, context)
    output = stream({"question": query})  # forward the user's query to the chain

    started = False
    buffer = ""
    for token in output:
        if not started:
            started = True
            yield StreamStart()
        yield Token(token)
        buffer += token
    yield FinalCommand(buffer)
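
For orientation, here is a hedged sketch of how a caller such as comai's CLI might consume this streaming interface. The module path of the new file is not shown in this diff, so query_command, StreamStart, Token and FinalCommand are assumed to be importable from it, and the printing behaviour is illustrative rather than taken from this commit:

# Hypothetical consumer of query_command; not part of this commit.
# Assumes the names above are imported from the new module and that
# `settings` and `context` objects are built elsewhere in comai.
def run_query(query, settings, context) -> str:
    command = ""
    for event in query_command(query, settings, context):
        if isinstance(event, StreamStart):
            pass  # e.g. start a spinner or clear the current line
        elif isinstance(event, Token):
            print(event.content, end="", flush=True)  # stream the command as it arrives
        elif isinstance(event, FinalCommand):
            command = event.command  # complete command, ready to confirm or execute
    print()
    return command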

The remaining changed files in this commit, including one that was deleted, are not rendered in this view.