diff --git a/src/funcchain/syntax/components/handler.py b/src/funcchain/syntax/components/handler.py index e6e8ef5..68f0375 100644 --- a/src/funcchain/syntax/components/handler.py +++ b/src/funcchain/syntax/components/handler.py @@ -1,63 +1,48 @@ -from typing import Union - -from langchain_core.chat_history import BaseChatMessageHistory from langchain_core.language_models import BaseChatModel -from langchain_core.messages import AIMessage, HumanMessage from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder -from langchain_core.runnables import Runnable from langchain_core.runnables.history import RunnableWithMessageHistory -from ...backend.settings import settings +from ...backend.settings import create_local_settings from ...model.defaults import univeral_model_selector +from ...schema.types import ChatHandler, OptionalChatHistoryFactory, UniversalChatModel +from ...utils.memory import InMemoryChatMessageHistory, create_history_factory from ...utils.msg_tools import msg_to_str -UniversalLLM = Union[BaseChatModel, str, None] - -def load_universal_llm(llm: UniversalLLM) -> BaseChatModel: +def load_universal_llm(llm: UniversalChatModel) -> BaseChatModel: if isinstance(llm, str): - settings.llm = llm + settings = create_local_settings({"llm": llm}) llm = None if not llm: llm = univeral_model_selector(settings) return llm -# def history_handler(input: Iterator[Any]) -> Iterator[Any]: - -# for chunk in input: -# yield chunk - - -def BasicChatHandler( +def create_chat_handler( *, - llm: UniversalLLM = None, - chat_history: BaseChatMessageHistory | None = None, + llm: UniversalChatModel = None, + history_factory: OptionalChatHistoryFactory = None, system_message: str = "", -) -> Runnable[HumanMessage, AIMessage]: - if chat_history is None: - from ...utils.memory import ChatMessageHistory - - chat_history = ChatMessageHistory() - +) -> ChatHandler: + history_factory = history_factory or create_history_factory(InMemoryChatMessageHistory) llm = load_universal_llm(llm) - handler_chain = ( + chat_handler_chain = ( ChatPromptTemplate.from_messages( [ - *(("system", system_message) if system_message else []), + *([("system", system_message)] if system_message else []), # todo test this MessagesPlaceholder(variable_name="history"), - ("human", "{user_msg}"), + ("human", "{message}"), ] ) | llm ) return { # todo handle images - "user_msg": lambda x: msg_to_str(x), + "message": lambda x: msg_to_str(x), } | RunnableWithMessageHistory( - handler_chain, # type: ignore - get_session_history=lambda _: chat_history, - input_messages_key="user_msg", + chat_handler_chain, # type: ignore + get_session_history=history_factory, + input_messages_key="message", history_messages_key="history", )