From 94e1cb10640107cba2d192e4eb60e9e180ed0783 Mon Sep 17 00:00:00 2001 From: Shroominic Date: Tue, 14 Nov 2023 19:08:15 +0100 Subject: [PATCH] =?UTF-8?q?=E2=9C=A8=20Add=20chat=20memory=20support?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- funcchain/chain.py | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/funcchain/chain.py b/funcchain/chain.py index 0ca33c8..4882bd5 100644 --- a/funcchain/chain.py +++ b/funcchain/chain.py @@ -6,6 +6,7 @@ from langchain.output_parsers.openai_functions import PydanticOutputFunctionsParser from langchain.prompts import ChatPromptTemplate from langchain.schema import AIMessage, BaseMessage, BaseOutputParser, HumanMessage +from langchain.schema.chat_history import BaseChatMessageHistory from langchain.schema.runnable import RunnableSequence, RunnableWithFallbacks from PIL import Image from pydantic.v1 import BaseModel @@ -101,6 +102,7 @@ def create_chain( system: str = settings.DEFAULT_SYSTEM_PROMPT, parser: BaseOutputParser[T] | None = None, context: list[BaseMessage] = [], + memory: BaseChatMessageHistory | None = None, input_kwargs: dict[str, str] = {}, ) -> RunnableSequence[dict[str, str], T]: output_type = get_output_type() @@ -122,6 +124,10 @@ def create_chain( elif images: raise RuntimeError("Images as input are only supported for vision models.") + if memory: + memory.add_user_message(instruction) + context = memory.messages + context + prompt = create_prompt(instruction, system, context, images=images, **input_kwargs) if func_model: @@ -160,12 +166,13 @@ def chain( system: str = settings.DEFAULT_SYSTEM_PROMPT, parser: BaseOutputParser[T] | None = None, context: list[BaseMessage] = [], + memory: BaseChatMessageHistory | None = None, **input_kwargs: str, ) -> T: # type: ignore """ Get response from chatgpt for provided instructions. """ - chain = create_chain(instruction, system, parser, context, input_kwargs) + chain = create_chain(instruction, system, parser, context, memory, input_kwargs) with get_openai_callback() as cb: result = chain.invoke(input_kwargs) @@ -174,6 +181,9 @@ def chain( f"{cb.total_tokens:05}T / {cb.total_cost:.3f}$ - {get_parent_frame(3).function}" ) + if memory and isinstance(result, str): + memory.add_ai_message(result) + return result @@ -183,12 +193,13 @@ async def achain( system: str = settings.DEFAULT_SYSTEM_PROMPT, parser: BaseOutputParser[T] | None = None, context: list[BaseMessage] = [], + memory: BaseChatMessageHistory | None = None, **input_kwargs: str, ) -> T: """ Get response from chatgpt for provided instructions. """ - chain = create_chain(instruction, system, parser, context, input_kwargs) + chain = create_chain(instruction, system, parser, context, memory, input_kwargs) with get_openai_callback() as cb: result = await chain.ainvoke(input_kwargs) @@ -197,4 +208,7 @@ async def achain( f"{cb.total_tokens:05}T / {cb.total_cost:.3f}$ - {get_parent_frame(3).function}" ) + if memory and isinstance(result, str): + memory.add_ai_message(result) + return result