From f3b78923bd069a7559f7a348573e464d1db60e54 Mon Sep 17 00:00:00 2001 From: Pedro Rico Pinazo Date: Mon, 6 May 2024 22:32:45 +0100 Subject: [PATCH] format and sort some imports --- src/comai/chain.py | 55 +++++++++++++++++++++++++++------------------- 1 file changed, 33 insertions(+), 22 deletions(-) diff --git a/src/comai/chain.py b/src/comai/chain.py index 0001c5b..5357343 100644 --- a/src/comai/chain.py +++ b/src/comai/chain.py @@ -1,22 +1,24 @@ -from langchain_core.runnables.utils import ConfigurableFieldSpec -from langchain_core.tools import tool -from langchain.globals import set_debug, set_verbose -from langchain_community.chat_models import ChatOllama -from langchain_core.messages import AIMessage, AIMessageChunk -from langchain_core.runnables import RunnableGenerator -from typing import Iterable, Iterator, List -import re -import logging -import uuid import itertools +import logging +import re import textwrap +import uuid from dataclasses import dataclass +from typing import Iterable, Iterator, List -from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder -from langchain_core.runnables.history import MessagesOrDictWithMessages, RunnableWithMessageHistory +from langchain.globals import set_debug, set_verbose from langchain_community.chat_message_histories import SQLChatMessageHistory +from langchain_community.chat_models import ChatOllama +from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage +from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder +from langchain_core.runnables import RunnableGenerator from langchain_core.runnables.base import Runnable -from langchain_core.messages import BaseMessage +from langchain_core.runnables.history import ( + MessagesOrDictWithMessages, + RunnableWithMessageHistory, +) +from langchain_core.runnables.utils import ConfigurableFieldSpec +from langchain_core.tools import tool from comai.context import Context from comai.history import load_history @@ -50,6 +52,7 @@ def create_prompt(context: Context): ] ) + def parse_command(tokens: Iterable[AIMessageChunk]) -> Iterator[str]: input = "" output: str = "" @@ -61,16 +64,19 @@ def parse_command(tokens: Iterable[AIMessageChunk]) -> Iterator[str]: if match: updated_output = match.group(1).strip() if len(updated_output) > len(output): - yield updated_output[len(output):] + yield updated_output[len(output) :] output = updated_output + parse_command_generator = RunnableGenerator(parse_command) -def attatch_history(runnable: Runnable[ - MessagesOrDictWithMessages, - str | BaseMessage | MessagesOrDictWithMessages, -]): +def attatch_history( + runnable: Runnable[ + MessagesOrDictWithMessages, + str | BaseMessage | MessagesOrDictWithMessages, + ] +): return RunnableWithMessageHistory( runnable, lambda session_id: load_history(session_id), @@ -87,33 +93,38 @@ def create_chain_stream(settings: Settings, context: Context): if context.session_id is not None: session_id = context.session_id else: - session_id = str(uuid.uuid4()) # FIXME: should just not use history at all + session_id = str(uuid.uuid4()) # FIXME: should just not use history at all chain_with_history = attatch_history(base_chain) runnable = chain_with_history | parse_command_generator def stream(input: dict[str, str]): return runnable.stream( - input=input, - config={"configurable": {"session_id": session_id}} + input=input, config={"configurable": {"session_id": session_id}} ) return stream + # TODO: move this to different file @dataclass class StreamStart: pass + @dataclass class Token: content: str + @dataclass class FinalCommand: command: str -def query_command(query: str, settings: Settings, context: Context) -> Iterator[StreamStart | Token | FinalCommand]: + +def query_command( + query: str, settings: Settings, context: Context +) -> Iterator[StreamStart | Token | FinalCommand]: stream = create_chain_stream(settings, context) output = stream({"question": "print your name"})