Skip to content

Commit

Permalink
✨ Add in-memory chat history
Browse files Browse the repository at this point in the history
  • Loading branch information
shroominic committed Feb 11, 2024
1 parent 73a6130 commit 0233a39
Showing 1 changed file with 52 additions and 1 deletion.
53 changes: 52 additions & 1 deletion src/funcchain/utils/memory.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,20 @@
"""langchain_community.chat_message_histories.in_memory.ChatMessageHistory"""
from typing import Any

from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.messages import BaseMessage
from langchain_core.pydantic_v1 import BaseModel, Field
from rich import print

from ..schema.types import ChatHistoryFactory


class ChatMessageHistory(BaseChatMessageHistory, BaseModel):
"""In memory implementation of chat message history.
Stores messages in an in memory list.
This is a copy from `langchain_community.chat_message_histories.in_memory.ChatMessageHistory`
to not require langchain_community as dependency only for this feature.
"""

messages: list[BaseMessage] = Field(default_factory=list)
Expand All @@ -19,3 +25,48 @@ def add_message(self, message: BaseMessage) -> None:

def clear(self) -> None:
self.messages = []


_in_memory_database: dict[str, list[BaseMessage]] = {}


class InMemoryChatMessageHistory(BaseChatMessageHistory):
"""In memory implementation of chat message history.
Stores messages in an in memory list.
"""

def __init__(self, session_id: str) -> None:
self.session_id = session_id
if session_id not in _in_memory_database:
_in_memory_database[session_id] = []

@property
def messages(self) -> list[BaseMessage]: # type: ignore
return _in_memory_database[self.session_id]

def add_message(self, message: BaseMessage) -> None:
_in_memory_database[self.session_id].append(message)

def add_messages(self, messages: list[BaseMessage]) -> None: # type: ignore
_in_memory_database[self.session_id].extend(messages)

def clear(self) -> None:
print(f"Clearing {self.session_id}")
del _in_memory_database[self.session_id][:]


def create_history_factory(
backend: type[BaseChatMessageHistory],
backend_kwargs: dict[str, Any] = {},
) -> ChatHistoryFactory:
"""
Create a function that returns a chat history.
"""

def history_factory(session_id: str, **kwargs: Any) -> BaseChatMessageHistory:
kwargs["session_id"] = session_id
kwargs.update(backend_kwargs)
return backend(**kwargs)

return history_factory

0 comments on commit 0233a39

Please sign in to comment.