Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Long Term Memory #3

Merged
merged 5 commits into from
Jan 7, 2025
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,11 @@ LANGCHAIN_TRACING_V2=true
LANGCHAIN_ENDPOINT="https://api.smith.langchain.com"
LANGCHAIN_API_KEY=...
LANGCHAIN_PROJECT="bitsgpt-rewrite"
POSTGRES_DB=
POSTGRES_USER=
POSTGRES_PASS=
POSTGRES_HOST=
POSTGRES_PORT=
```

Replace the keys with the appropriate values.
Expand Down
957 changes: 947 additions & 10 deletions poetry.lock

Large diffs are not rendered by default.

3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@ langsmith = "^0.2.7"
langchain-groq = "^0.2.2"
langgraph-cli = {extras = ["inmem"], version = "^0.1.65"}
pre-commit = "^4.0.1"
psycopg2-binary = "^2.9.10"
langchain-core = "^0.3.29"
langchain = "^0.3.14"


[build-system]
Expand Down
27 changes: 21 additions & 6 deletions src/agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,12 @@
import textwrap

from dotenv import load_dotenv
from langchain.agents import AgentExecutor, create_tool_calling_agent
from langchain_core.prompts import ChatPromptTemplate
from langchain_groq import ChatGroq

from src.tools.memory_tool import tool_modify_memory

load_dotenv()


Expand All @@ -26,7 +29,7 @@ def __init__(self):
self.prompts[agent_name] = f.read()

def get_prompt(
self, agent_name: str, query: str, chat_history: str, agent_scratchpad=False
self, agent_name: str, user_input: str, agent_scratchpad=False
) -> ChatPromptTemplate:

prompt = [
Expand All @@ -36,17 +39,18 @@ def get_prompt(
),
(
"user",
textwrap.dedent(
f"<query>{query}</query>\n\n<history>{chat_history}</history>"
),
textwrap.dedent(user_input),
),
]
if agent_scratchpad:
prompt.append(("placeholder", "{agent_scratchpad}"))
return ChatPromptTemplate.from_messages(prompt)

def intent_classifier(self, query: str, chat_history: str) -> str:
prompt = self.get_prompt("INTENT_CLASSIFIER_AGENT", query, chat_history)
prompt = self.get_prompt(
"INTENT_CLASSIFIER_AGENT",
f"<query>{query}</query>\n\n<history>{chat_history}</history>",
)

chain = prompt | self.llm

Expand All @@ -59,7 +63,10 @@ def intent_classifier(self, query: str, chat_history: str) -> str:
return result

def general_campus_query(self, query: str, chat_history: str) -> str:
prompt = self.get_prompt("GENERAL_CAMPUS_QUERY_AGENT", query, chat_history)
prompt = self.get_prompt(
"GENERAL_CAMPUS_QUERY_AGENT",
f"<query>{query}</query>\n\n<history>{chat_history}</history>",
)

chain = prompt | self.llm

Expand All @@ -73,3 +80,11 @@ def general_campus_query(self, query: str, chat_history: str) -> str:

def course_query(self, query: str, chat_history: str) -> str:
raise NotImplementedError("Course query not implemented yet")

def long_term_memory(self, id: str, query: str, memories: str) -> str:
tools = [tool_modify_memory]
prompt = self.get_prompt("LONG_TERM_MEMORY_AGENT", query, agent_scratchpad=True)
agent = create_tool_calling_agent(self.llm, tools, prompt)
chain = AgentExecutor(agent=agent, tools=tools)
result = chain.invoke({"user_id": id, "memories": memories})
return result["output"]
22 changes: 18 additions & 4 deletions src/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
course_query,
general_campus_query,
intent_classifier,
long_term_memory,
not_related_query,
)
from .state import State
Expand All @@ -21,6 +22,7 @@ def create_graph(self) -> StateGraph:
graph.add_node("course_query", course_query)
graph.add_node("general_campus_query", general_campus_query)
graph.add_node("not_related_query", not_related_query)
graph.add_node("long_term_memory", long_term_memory)

graph.set_entry_point("intent_classifer")

Expand All @@ -34,8 +36,20 @@ def intent_router(state):

graph.add_conditional_edges("intent_classifer", intent_router)

graph.add_edge("course_query", END)
graph.add_edge("general_campus_query", END)
graph.add_edge("not_related_query", END)

graph.add_edge(
"course_query",
"long_term_memory",
)
graph.add_edge(
"general_campus_query",
"long_term_memory",
)
graph.add_edge(
"not_related_query",
"long_term_memory",
)
graph.add_edge(
"long_term_memory",
END,
)
return graph
39 changes: 39 additions & 0 deletions src/memory/data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
from enum import Enum
from typing import Optional

from langchain.pydantic_v1 import BaseModel, Field


class Category(str, Enum):
Course_Like = "Course Likes"
Course_Dislike = "Course Dislikes"
Branch = "Branch"
Clubs = "Clubs"
Person_Attribute = "Person Attributes"


class Action(str, Enum):
Create = "Create"
Update = "Update"
Delete = "Delete"


class AddMemory(BaseModel):
id: str = Field(..., description="The ID of the user")
memory: str = Field(
...,
description="Condensed bit of knowledge to be saved for future reference in the format: [fact to store] (e.g. Likes Thermodynamics; Dislikes General Biology; Branch is Computer Science; Part of CRUx the best club on campus; Interseted in Math, etc.)",
)
memory_old: Optional[str] = Field(
None,
description="If updating or deleting memory record, the complete, exact phrase that needs to be modified",
)
category: Category = Field(..., description="The category of the memory")
action: Action = Field(
...,
description="Whether this memory is adding a new record, updating a record, or deleting a record",
)


def parse_memory(memory: AddMemory):
raise NotImplementedError("This function is not yet implemented")
13 changes: 13 additions & 0 deletions src/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,3 +31,16 @@ def not_related_query(state: State):
"I'm sorry, I don't understand the question, if it relates to campus please rephrase."
)
return {"messages": [result]}


def long_term_memory(state: State):
query = state["messages"][0].content
user_id = state["user_id"]
# parse long term memory here.
long_term_memories = state.get("long_term_memories", "")
result = agents.long_term_memory(
user_id,
query,
long_term_memories,
)
return {"messages": [result]}
37 changes: 37 additions & 0 deletions src/prompts/LONG_TERM_MEMORY_AGENT.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
You are a supervisor managing a team of knowledge experts.

Your team's job is to create a perfect knowledge base about a users - {user_id} college campus related likes, dislikes habits etc. to assist in highly personalized interactions with the user.

The knowledge base should ultimately consist of many discrete pieces of information that add up to a rich persona (e.g. I like the course CSF111; I hate General Bilogy; I am part of CRUx the coding club of the campus; Pursuing a Computer Science degree; Hate eating in the mess; etc).

Every time you receive a message, you will evaluate if it has any information worth recording in the knowledge base.

A message may contain multiple pieces of information that should be saved separately.

The users id is {user_id}.

You are only interested in the following categories of information:

1. Course prefereces or likes - The user likes or is interested in a course.
2. Course dislikes - The user dislikes a course.
3. Branch - The branch the user is pursuing in college including majors and minors.
4. Clubs - The clubs the user is part of on campus.
5. Personal attributes - Any personal information that the user provides. (e.g. Campus eating habits, Campus sports, Fests etc.). Keep this limited to the context of the campus.

When you receive a message, you perform a sequence of steps consisting of:

1. Analyze the most recent Human message for information. You will see multiple messages for context, but we are only looking for new information in the most recent message.
2. Compare this to the knowledge you already have.
3. Determine if this is new knowledge, an update to old knowledge that now needs to change, or should result in deleting information that is not correct. It's possible that a product/brand you previously wrote as a dislike might now be a like, and other cases- those examples would require an update.
4. Never save the same information twice. If you see the same information in a message that you have already saved, ignore it.
5. Refer to the history for existing memories.

Here are the existing bits of information that we have about the user.

{memories}

Call the right tools to save the information, then respond with DONE. If you identiy multiple pieces of information, call everything at once. You only have one chance to call tools.

I will tip you $20 if you are perfect, and I will fine you $40 if you miss any important information or change any incorrect information.

Take a deep breath, think step by step, and then analyze the following message:
4 changes: 4 additions & 0 deletions src/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,11 @@
from langgraph.graph.message import add_messages
from typing_extensions import TypedDict

from src.memory.data import AddMemory


class State(TypedDict):
user_id: str
messages: Annotated[list, add_messages]
chat_history: Optional[str]
long_term_memories: Optional[list[AddMemory]]
73 changes: 73 additions & 0 deletions src/tools/memory_tool.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
import os

import psycopg2
from dotenv import load_dotenv
from langchain.tools import StructuredTool

from src.memory.data import AddMemory

load_dotenv()


def modify_memory(
id: str, memory: str, category: str, action: str, memory_old: str = None
):
"""
Function to modify memory in the database
"""
print(f"Modifying memory for {id} with action {action}")
conn = psycopg2.connect(
dbname=os.getenv("POSTGRES_DB"),
user=os.getenv("POSTGRES_USER"),
password=os.getenv("POSTGRES_PASS"),
host=os.getenv("POSTGRES_HOST"),
port=os.getenv("POSTGRES_PORT"),
)
cur = conn.cursor()
cur.execute(
"""
CREATE TABLE IF NOT EXISTS longterm_memory (
id VARCHAR(255) NOT NULL,
memory VARCHAR(255) NOT NULL,
category VARCHAR(255) NOT NULL
);
"""
)
if action == "Create":
cur.execute(
f"""
INSERT INTO longterm_memory (id, memory, category)
VALUES ('{id}', '{memory}', '{category}');
"""
)
conn.commit()
return "Memory created successfully"
elif action == "Update":
cur.execute(
f"""
UPDATE longterm_memory
SET memory = '{memory}'
WHERE id = '{id}' AND memory = '{memory_old}';
"""
)
conn.commit()
return "Memory updated successfully"
elif action == "Delete":
cur.execute(
f"""
DELETE FROM longterm_memory
WHERE id = '{id}' AND memory = '{memory_old}' AND category = '{category}';
"""
)
conn.commit()
return "Memory deleted successfully"
else:
return "Invalid action"


tool_modify_memory = StructuredTool.from_function(
func=modify_memory,
name="modify_memory",
description="Modify the long term memory of a user",
args_schema=AddMemory,
)
Loading