Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Fix mirascope integration errors and streaming endpoint
Browse files Browse the repository at this point in the history
VVoruganti committed May 15, 2024
1 parent c776d2b commit 533b615
Showing 3 changed files with 27 additions and 37 deletions.
17 changes: 6 additions & 11 deletions src/agent.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
import os
import uuid
from typing import Optional

from dotenv import load_dotenv
from mirascope.openai import OpenAICall, OpenAICallParams
@@ -10,7 +8,6 @@

load_dotenv()

# from supabase import Client

class Dialectic(OpenAICall):
prompt_template = """
@@ -55,7 +52,10 @@ async def prep_inference(
if len(retrieved_documents) > 0:
retrieved_facts = retrieved_documents[0].content

chain = Dialectic(agent_input=query, retrieved_facts=retrieved_facts if retrieved_facts else "None")
chain = Dialectic(
agent_input=query,
retrieved_facts=retrieved_facts if retrieved_facts else "None",
)
return chain


@@ -65,9 +65,7 @@ async def chat(
query: str,
db: AsyncSession,
):
chain = await prep_inference(
db, app_id, user_id, query
)
chain = await prep_inference(db, app_id, user_id, query)
response = await chain.call_async()

return schemas.AgentChat(content=response.content)
@@ -76,11 +74,8 @@ async def chat(
async def stream(
app_id: uuid.UUID,
user_id: uuid.UUID,
session_id: uuid.UUID,
query: str,
db: AsyncSession,
):
chain = await prep_inference(
db, app_id, user_id, query
)
chain = await prep_inference(db, app_id, user_id, query)
return chain.stream_async()
28 changes: 14 additions & 14 deletions src/deriver.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,15 @@
import asyncio
import os, re
import time
import os
import re
import uuid
from typing import List

import sentry_sdk
from dotenv import load_dotenv

from mirascope.openai import OpenAICall, OpenAICallParams

from realtime.connection import Socket
from sqlalchemy import select
from sqlalchemy.orm import selectinload
from websockets.exceptions import ConnectionClosedError

from . import crud, models, schemas
from .db import SessionLocal
@@ -30,6 +27,7 @@
SUPABASE_ID = os.getenv("SUPABASE_ID")
SUPABASE_API_KEY = os.getenv("SUPABASE_API_KEY")


class DeriveFacts(OpenAICall):
prompt_template = """
You are tasked with deriving discrete facts about the user based on their input. The goal is to only extract absolute facts from the message, do not make inferences beyond the text provided.
@@ -45,13 +43,14 @@ class DeriveFacts(OpenAICall):

call_params = OpenAICallParams(model="gpt-4o-2024-05-13")


class CheckDups(OpenAICall):
prompt_template = """
Your job is to determine if the new fact exists in the old:
Old: ```{existing_facts}```
New: ```{facts}```
New: ```{fact}```
If the new fact is sufficiently represented in the old list, return False. Otherwise, if the fact is indeed new, return True.
"""
@@ -128,8 +127,12 @@ async def process_user_message(
# contents = [m.content for m in messages]
# print(contents)

chat_history_str = "\n".join([f"user: {m.content}" if m.is_user else f"ai: {m.content}" for m in messages])
facts_response = await DeriveFacts(chat_history=chat_history_str, user_input=content).call_async()
chat_history_str = "\n".join(
[f"user: {m.content}" if m.is_user else f"ai: {m.content}" for m in messages]
)
facts_response = await DeriveFacts(
chat_history=chat_history_str, user_input=content
).call_async()
facts = re.findall(r"\d+\.\s([^\n]+)", facts_response.content)

print("===================")
@@ -156,7 +159,6 @@ async def process_user_message(
# print(f"Created fact: {fact}")



async def check_dups(
app_id: uuid.UUID, user_id: uuid.UUID, collection_id: uuid.UUID, facts: List[str]
):
@@ -181,7 +183,7 @@ async def check_dups(
new_facts.append(fact)
print(f"New Fact: {fact}")
continue

global_existing_facts.extend(existing_facts) # for debugging

check_duplication.existing_facts = existing_facts
@@ -192,18 +194,16 @@ async def check_dups(
print(f"New Fact: {fact}")
continue


print("===================")
print(f"Existing Facts: {global_existing_facts}")
print(f"Net New Facts {new_facts}")
print("===================")
return new_facts



if __name__ == "__main__":
URL = f"wss://{SUPABASE_ID}.supabase.co/realtime/v1/websocket?apikey={SUPABASE_API_KEY}&vsn=1.0.0"
# URL = f"ws://127.0.0.1:54321/realtime/v1/websocket?apikey={SUPABASE_API_KEY}" # For local Supabase
# URL = f"wss://{SUPABASE_ID}.supabase.co/realtime/v1/websocket?apikey={SUPABASE_API_KEY}&vsn=1.0.0"
URL = f"ws://127.0.0.1:54321/realtime/v1/websocket?apikey={SUPABASE_API_KEY}" # For local Supabase
# listen_to_websocket(URL)
s = Socket(URL)
s.connect()
19 changes: 7 additions & 12 deletions src/routers/sessions.py
Original file line number Diff line number Diff line change
@@ -196,9 +196,7 @@ async def get_chat(
db=db,
auth=Depends(auth),
):
return await agent.chat(
app_id=app_id, user_id=user_id, session_id=session_id, query=query, db=db
)
return await agent.chat(app_id=app_id, user_id=user_id, query=query, db=db)


@router.get(
@@ -221,14 +219,11 @@ async def get_chat_stream(
db=db,
auth=Depends(auth),
):
async def parse_stream():
stream = await agent.stream(app_id=app_id, user_id=user_id, query=query, db=db)
async for chunk in stream:
yield chunk.content

return StreamingResponse(
await agent.stream(
app_id=app_id,
user_id=user_id,
session_id=session_id,
query=query,
db=db,
),
media_type="text/event-stream",
status_code=200,
content=parse_stream(), media_type="text/event-stream", status_code=200
)

0 comments on commit 533b615

Please sign in to comment.