From f0b388fb3ab3cce1f3d4f49d2deee25a716d1cea Mon Sep 17 00:00:00 2001 From: Shroominic Date: Sat, 15 Jun 2024 14:51:58 -0700 Subject: [PATCH] =?UTF-8?q?=F0=9F=9A=A7=20still=20wip:=20chat=20handler=20?= =?UTF-8?q?factory?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/funcchain/syntax/components/handler.py | 42 +++++++++------------- 1 file changed, 17 insertions(+), 25 deletions(-) diff --git a/src/funcchain/syntax/components/handler.py b/src/funcchain/syntax/components/handler.py index 68f0375..4540a4f 100644 --- a/src/funcchain/syntax/components/handler.py +++ b/src/funcchain/syntax/components/handler.py @@ -1,12 +1,11 @@ +from typing import Any + from langchain_core.language_models import BaseChatModel from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder -from langchain_core.runnables.history import RunnableWithMessageHistory from ...backend.settings import create_local_settings from ...model.defaults import univeral_model_selector -from ...schema.types import ChatHandler, OptionalChatHistoryFactory, UniversalChatModel -from ...utils.memory import InMemoryChatMessageHistory, create_history_factory -from ...utils.msg_tools import msg_to_str +from ...schema.types import ChatRunnable, UniversalChatModel def load_universal_llm(llm: UniversalChatModel) -> BaseChatModel: @@ -21,28 +20,21 @@ def load_universal_llm(llm: UniversalChatModel) -> BaseChatModel: def create_chat_handler( *, llm: UniversalChatModel = None, - history_factory: OptionalChatHistoryFactory = None, - system_message: str = "", -) -> ChatHandler: - history_factory = history_factory or create_history_factory(InMemoryChatMessageHistory) - llm = load_universal_llm(llm) - - chat_handler_chain = ( - ChatPromptTemplate.from_messages( + system_message: str | None, + tools: list[str] = [], + vision: bool = False, + read_files: bool = False, + read_links: bool = False, + code_interpreter: bool = False, + **kwargs: Any, +) -> ChatRunnable: + return ( + {"messages": lambda x: x} + | ChatPromptTemplate.from_messages( [ - *([("system", system_message)] if system_message else []), # todo test this - MessagesPlaceholder(variable_name="history"), - ("human", "{message}"), + *([("system", system_message)] if system_message else []), + MessagesPlaceholder(variable_name="messages"), ] ) - | llm - ) - return { - # todo handle images - "message": lambda x: msg_to_str(x), - } | RunnableWithMessageHistory( - chat_handler_chain, # type: ignore - get_session_history=history_factory, - input_messages_key="message", - history_messages_key="history", + | load_universal_llm(llm) # type: ignore )