Mirascope deriver (#56)

* ready for testing

* delete prompts folder, mirascope colocation ftw

* Fix mirascope integration errors and streaming endpoint

---------

Co-authored-by: vintro <[email protected]>
VVoruganti and vintrocode authored May 15, 2024
1 parent 9993723 commit 04753e4
Showing 8 changed files with 581 additions and 1,647 deletions.
1,945 changes: 476 additions & 1,469 deletions poetry.lock

Large diffs are not rendered by default.

5 changes: 2 additions & 3 deletions pyproject.toml
@@ -6,7 +6,7 @@ authors = ["Plastic Labs <[email protected]>"]
readme = "README.md"

[tool.poetry.dependencies]
python = "^3.8.1"
python = "^3.9"
fastapi = "^0.109.0"
uvicorn = "^0.24.0.post1"
python-dotenv = "^1.0.0"
@@ -25,11 +25,10 @@ opentelemetry-instrumentation-logging = "^0.44b0"
greenlet = "^3.0.3"
realtime = "^1.0.2"
psycopg = {extras = ["binary"], version = "^3.1.18"}
langchain = "^0.1.12"
langchain-openai = "^0.0.8"
httpx = "^0.27.0"
uvloop = "^0.19.0"
httptools = "^0.6.1"
mirascope = {extras = ["openai"], version = "^0.12.3"}

[tool.ruff.lint]
# from https://docs.astral.sh/ruff/linter/#rule-selection example
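The dependency change above swaps langchain's YAML prompt files for Mirascope call classes that colocate a prompt with its parameters and fields. A minimal sketch of the pattern (the `Greet` class is illustrative, not part of this commit; it mirrors the Mirascope 0.12 API used in the files below):

```python
from mirascope.openai import OpenAICall, OpenAICallParams


class Greet(OpenAICall):
    # The prompt lives on the class; {name} is filled from the field below.
    prompt_template = """
    Write a one-line greeting for {name}.
    """

    name: str

    call_params = OpenAICallParams(model="gpt-4o-2024-05-13")


# Synchronous call; .call_async() and .stream_async() are the async variants.
response = Greet(name="Ada").call()
print(response.content)
```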
62 changes: 22 additions & 40 deletions src/agent.py
@@ -1,37 +1,33 @@
import os
import uuid
from typing import Optional

from dotenv import load_dotenv
from langchain_core.prompts import (
ChatPromptTemplate,
SystemMessagePromptTemplate,
load_prompt,
)
from langchain_openai import ChatOpenAI
from mirascope.openai import OpenAICall, OpenAICallParams
from sqlalchemy.ext.asyncio import AsyncSession

from . import crud, schemas

load_dotenv()

# from supabase import Client

SYSTEM_DIALECTIC = load_prompt(
os.path.join(os.path.dirname(__file__), "prompts/dialectic.yaml")
)
system_dialectic: SystemMessagePromptTemplate = SystemMessagePromptTemplate(
prompt=SYSTEM_DIALECTIC
)
class Dialectic(OpenAICall):
prompt_template = """
You are tasked with responding to the query based on the context provided.
---
query: {agent_input}
context: {retrieved_facts}
---
Provide a brief, matter-of-fact, and appropriate response to the query based on the context provided. If the context provided doesn't aid in addressing the query, return None.
"""
agent_input: str
retrieved_facts: str

llm: ChatOpenAI = ChatOpenAI(model_name="gpt-4")
call_params = OpenAICallParams(model="gpt-4o-2024-05-13")


async def prep_inference(
db: AsyncSession,
app_id: uuid.UUID,
user_id: uuid.UUID,
query: str,
):
collection = await crud.get_collection_by_name(db, app_id, user_id, "honcho")
@@ -56,44 +52,30 @@ async def prep_inference(
if len(retrieved_documents) > 0:
retrieved_facts = retrieved_documents[0].content

dialectic_prompt = ChatPromptTemplate.from_messages([system_dialectic])
chain = dialectic_prompt | llm
return (chain, retrieved_facts)
chain = Dialectic(
agent_input=query,
retrieved_facts=retrieved_facts if retrieved_facts else "None",
)
return chain


async def chat(
app_id: uuid.UUID,
user_id: uuid.UUID,
session_id: uuid.UUID,
query: str,
db: AsyncSession,
):
(chain, retrieved_facts) = await prep_inference(
db, app_id, user_id, session_id, query
)
response = await chain.ainvoke(
{
"agent_input": query,
"retrieved_facts": retrieved_facts if retrieved_facts else "None",
}
)
chain = await prep_inference(db, app_id, user_id, query)
response = await chain.call_async()

return schemas.AgentChat(content=response.content)


async def stream(
app_id: uuid.UUID,
user_id: uuid.UUID,
session_id: uuid.UUID,
query: str,
db: AsyncSession,
):
(chain, retrieved_facts) = await prep_inference(
db, app_id, user_id, session_id, query
)
return chain.astream(
{
"agent_input": query,
"retrieved_facts": retrieved_facts if retrieved_facts else "None",
}
)
chain = await prep_inference(db, app_id, user_id, query)
return chain.stream_async()
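With Mirascope, `stream()` now just returns the call's async chunk generator, and the streaming endpoint iterates it. A self-contained sketch of consuming `stream_async()` (the `Echo` class is illustrative; it assumes each chunk exposes a `.content` string, as Mirascope 0.12 chunks do):

```python
import asyncio

from mirascope.openai import OpenAICall, OpenAICallParams


class Echo(OpenAICall):
    prompt_template = "Repeat this back verbatim: {text}"

    text: str

    call_params = OpenAICallParams(model="gpt-4o-2024-05-13")


async def main() -> None:
    # stream_async() yields response chunks as they arrive, which is what
    # src/agent.py's stream() hands back to the HTTP layer.
    async for chunk in Echo(text="hello").stream_async():
        print(chunk.content, end="", flush=True)


asyncio.run(main())
```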
165 changes: 74 additions & 91 deletions src/deriver.py
@@ -1,22 +1,15 @@
import asyncio
import os
import time
import re
import uuid
from typing import List

import sentry_sdk
from dotenv import load_dotenv
from langchain_core.output_parsers import NumberedListOutputParser
from langchain_core.prompts import (
ChatPromptTemplate,
SystemMessagePromptTemplate,
load_prompt,
)
from langchain_openai import ChatOpenAI
from mirascope.openai import OpenAICall, OpenAICallParams
from realtime.connection import Socket
from sqlalchemy import select
from sqlalchemy.orm import selectinload
from websockets.exceptions import ConnectionClosedError

from . import crud, models, schemas
from .db import SessionLocal
@@ -34,23 +27,37 @@
SUPABASE_ID = os.getenv("SUPABASE_ID")
SUPABASE_API_KEY = os.getenv("SUPABASE_API_KEY")

llm = ChatOpenAI(model_name="gpt-3.5-turbo")
output_parser = NumberedListOutputParser()

SYSTEM_DERIVE_FACTS = load_prompt(
os.path.join(os.path.dirname(__file__), "prompts/derive_facts.yaml")
)
SYSTEM_CHECK_DUPS = load_prompt(
os.path.join(os.path.dirname(__file__), "prompts/check_dup_facts.yaml")
)
system_check_dups: SystemMessagePromptTemplate = SystemMessagePromptTemplate(
prompt=SYSTEM_CHECK_DUPS
)
system_derive_facts: SystemMessagePromptTemplate = SystemMessagePromptTemplate(
prompt=SYSTEM_DERIVE_FACTS
)
class DeriveFacts(OpenAICall):
prompt_template = """
You are tasked with deriving discrete facts about the user based on their input. The goal is to extract only absolute facts from the message; do not make inferences beyond the text provided.

chat history: ```{chat_history}```
user input: ```{user_input}```

Output the facts as a numbered list.
"""

chat_history: str
user_input: str

call_params = OpenAICallParams(model="gpt-4o-2024-05-13")


class CheckDups(OpenAICall):
prompt_template = """
Your job is to determine if the new fact exists in the old:
Old: ```{existing_facts}```
New: ```{fact}```
If the new fact is sufficiently represented in the old list, return False. Otherwise, if the fact is indeed new, return True.
"""
existing_facts: List[str]
fact: str

call_params = OpenAICallParams(model="gpt-4o-2024-05-13")


async def callback(payload):
@@ -106,6 +113,9 @@ async def process_user_message(
collection_id: uuid.UUID,
message_id: uuid.UUID,
):
"""
Process a user message and derive facts from it (check for duplicates before writing to the collection).
"""
async with SessionLocal() as db:
messages_stmt = await crud.get_messages(
db=db, app_id=app_id, user_id=user_id, session_id=session_id, reverse=True
@@ -117,16 +127,19 @@
# contents = [m.content for m in messages]
# print(contents)

facts = await derive_facts(messages, content)
chat_history_str = "\n".join(
[f"user: {m.content}" if m.is_user else f"ai: {m.content}" for m in messages]
)
facts_response = await DeriveFacts(
chat_history=chat_history_str, user_input=content
).call_async()
facts = re.findall(r"\d+\.\s([^\n]+)", facts_response.content)

print("===================")
print(f"DERIVED FACTS: {facts}")
print("===================")
new_facts = await check_dups(app_id, user_id, collection_id, facts)

print("===================")
print(f"CHECKED FOR DUPLICATES: {new_facts}")
print("===================")

for fact in new_facts:
create_document = schemas.DocumentCreate(content=fact)
async with SessionLocal() as db:
@@ -146,81 +159,51 @@ async def process_user_message(
# print(f"Created fact: {fact}")


async def derive_facts(chat_history, input: str) -> List[str]:
"""Derive facts from the user input"""

fact_derivation = ChatPromptTemplate.from_messages([system_derive_facts])
chain = fact_derivation | llm
response = await chain.ainvoke(
{
"chat_history": [
(
"user: " + message.content
if message.is_user
else "ai: " + message.content
)
for message in chat_history
],
"user_input": input,
}
)
facts = output_parser.parse(response.content)

return facts


async def check_dups(
app_id: uuid.UUID, user_id: uuid.UUID, collection_id: uuid.UUID, facts: List[str]
):
"""Check that we're not storing duplicate facts"""

check_duplication = ChatPromptTemplate.from_messages([system_check_dups])
query = " ".join(facts)
check_duplication = CheckDups(existing_facts=[], fact="")
result = None
async with SessionLocal() as db:
result = await crud.query_documents(
db=db,
app_id=app_id,
user_id=user_id,
collection_id=collection_id,
query=query,
top_k=10,
)
# result = collection.query(query=query, top_k=10)
existing_facts = [document.content for document in result]
print("===================")
print(f"Existing Facts {existing_facts}")
print("===================")
if len(existing_facts) == 0:
return facts
chain = check_duplication | llm
response = await chain.ainvoke({"existing_facts": existing_facts, "facts": facts})
new_facts = output_parser.parse(response.content)
new_facts = []
global_existing_facts = [] # for debugging
for fact in facts:
async with SessionLocal() as db:
result = await crud.query_documents(
db=db,
app_id=app_id,
user_id=user_id,
collection_id=collection_id,
query=fact,
top_k=5,
)
existing_facts = [document.content for document in result]
if len(existing_facts) == 0:
new_facts.append(fact)
print(f"New Fact: {fact}")
continue

global_existing_facts.extend(existing_facts) # for debugging

check_duplication.existing_facts = existing_facts
check_duplication.fact = fact
response = await check_duplication.call_async()
if response.content == "True":
new_facts.append(fact)
print(f"New Fact: {fact}")
continue

print("===================")
print(f"New Facts {facts}")
print(f"Existing Facts: {global_existing_facts}")
print(f"Net New Facts {new_facts}")
print("===================")
return new_facts


# def listen_to_websocket(url):
# while True:
# try:
# s = Socket(url)
# s.connect()
# channel = s.set_channel("realtime:public:messages")
# channel.join().on(
# "INSERT", lambda payload: asyncio.create_task(callback(payload))
# )

# s.listen()
# except ConnectionClosedError:
# print("Connection closed, attempting to reconnect...")
# time.sleep(5)


if __name__ == "__main__":
URL = f"wss://{SUPABASE_ID}.supabase.co/realtime/v1/websocket?apikey={SUPABASE_API_KEY}&vsn=1.0.0"
# URL = f"ws://127.0.0.1:54321/realtime/v1/websocket?apikey={SUPABASE_API_KEY}" # For local Supabase
# URL = f"wss://{SUPABASE_ID}.supabase.co/realtime/v1/websocket?apikey={SUPABASE_API_KEY}&vsn=1.0.0"
URL = f"ws://127.0.0.1:54321/realtime/v1/websocket?apikey={SUPABASE_API_KEY}" # For local Supabase
# listen_to_websocket(URL)
s = Socket(URL)
s.connect()
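The diff is truncated at this point; judging from the commented-out `listen_to_websocket` helper above, the rest of the `__main__` block presumably joins the messages channel and listens. A sketch under that assumption, reusing only calls that appear in the helper:

```python
channel = s.set_channel("realtime:public:messages")
channel.join().on("INSERT", lambda payload: asyncio.create_task(callback(payload)))
s.listen()
```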
11 changes: 0 additions & 11 deletions src/prompts/check_dup_facts.yaml

This file was deleted.

10 changes: 0 additions & 10 deletions src/prompts/derive_facts.yaml

This file was deleted.

11 changes: 0 additions & 11 deletions src/prompts/dialectic.yaml

This file was deleted.
