
format and sort some imports
ricopinazo committed May 6, 2024
1 parent 9e3e9cb commit f3b7892
Showing 1 changed file with 33 additions and 22 deletions.
55 changes: 33 additions & 22 deletions src/comai/chain.py
@@ -1,22 +1,24 @@
from langchain_core.runnables.utils import ConfigurableFieldSpec
from langchain_core.tools import tool
from langchain.globals import set_debug, set_verbose
from langchain_community.chat_models import ChatOllama
from langchain_core.messages import AIMessage, AIMessageChunk
from langchain_core.runnables import RunnableGenerator
from typing import Iterable, Iterator, List
import re
import logging
import uuid
import itertools
import logging
import re
import textwrap
import uuid
from dataclasses import dataclass
from typing import Iterable, Iterator, List

from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.runnables.history import MessagesOrDictWithMessages, RunnableWithMessageHistory
from langchain.globals import set_debug, set_verbose
from langchain_community.chat_message_histories import SQLChatMessageHistory
from langchain_community.chat_models import ChatOllama
from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.runnables import RunnableGenerator
from langchain_core.runnables.base import Runnable
from langchain_core.messages import BaseMessage
from langchain_core.runnables.history import (
    MessagesOrDictWithMessages,
    RunnableWithMessageHistory,
)
from langchain_core.runnables.utils import ConfigurableFieldSpec
from langchain_core.tools import tool

from comai.context import Context
from comai.history import load_history
@@ -50,6 +52,7 @@ def create_prompt(context: Context):
        ]
    )


def parse_command(tokens: Iterable[AIMessageChunk]) -> Iterator[str]:
input = ""
output: str = ""
@@ -61,16 +64,19 @@ def parse_command(tokens: Iterable[AIMessageChunk]) -> Iterator[str]:
        if match:
            updated_output = match.group(1).strip()
            if len(updated_output) > len(output):
                yield updated_output[len(output):]
                yield updated_output[len(output) :]
                output = updated_output


parse_command_generator = RunnableGenerator(parse_command)


def attatch_history(runnable: Runnable[
    MessagesOrDictWithMessages,
    str | BaseMessage | MessagesOrDictWithMessages,
]):
def attatch_history(
    runnable: Runnable[
        MessagesOrDictWithMessages,
        str | BaseMessage | MessagesOrDictWithMessages,
    ]
):
    return RunnableWithMessageHistory(
        runnable,
        lambda session_id: load_history(session_id),
@@ -87,33 +93,38 @@ def create_chain_stream(settings: Settings, context: Context):
    if context.session_id is not None:
        session_id = context.session_id
    else:
        session_id = str(uuid.uuid4()) # FIXME: should just not use history at all
        session_id = str(uuid.uuid4())  # FIXME: should just not use history at all

    chain_with_history = attatch_history(base_chain)
    runnable = chain_with_history | parse_command_generator

    def stream(input: dict[str, str]):
        return runnable.stream(
            input=input,
            config={"configurable": {"session_id": session_id}}
            input=input, config={"configurable": {"session_id": session_id}}
        )

    return stream


# TODO: move this to different file
@dataclass
class StreamStart:
    pass


@dataclass
class Token:
    content: str


@dataclass
class FinalCommand:
    command: str

def query_command(query: str, settings: Settings, context: Context) -> Iterator[StreamStart | Token | FinalCommand]:

def query_command(
    query: str, settings: Settings, context: Context
) -> Iterator[StreamStart | Token | FinalCommand]:
    stream = create_chain_stream(settings, context)
    output = stream({"question": "print your name"})

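For context (not part of this commit): a minimal sketch of how a caller might consume the StreamStart / Token / FinalCommand events that query_command yields. It assumes Settings is importable from a comai.settings module; that import and the body of query_command sit outside the visible hunks of this diff.

# Illustrative only: a hypothetical consumer of query_command's event stream.
# Assumes Settings comes from comai.settings (not shown in this diff).
from comai.chain import FinalCommand, StreamStart, Token, query_command
from comai.context import Context
from comai.settings import Settings

def print_command(query: str, settings: Settings, context: Context) -> str | None:
    command: str | None = None
    for event in query_command(query, settings, context):
        if isinstance(event, StreamStart):
            pass  # model started streaming; a CLI could stop a spinner here
        elif isinstance(event, Token):
            print(event.content, end="", flush=True)  # echo partial output
        elif isinstance(event, FinalCommand):
            command = event.command  # fully parsed shell command
    print()
    return command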
