diff --git a/src/agent.py b/src/agent.py index 43e770e..c22d33d 100644 --- a/src/agent.py +++ b/src/agent.py @@ -1,6 +1,4 @@ -import os import uuid -from typing import Optional from dotenv import load_dotenv from mirascope.openai import OpenAICall, OpenAICallParams @@ -10,7 +8,6 @@ load_dotenv() -# from supabase import Client class Dialectic(OpenAICall): prompt_template = """ @@ -55,7 +52,10 @@ async def prep_inference( if len(retrieved_documents) > 0: retrieved_facts = retrieved_documents[0].content - chain = Dialectic(agent_input=query, retrieved_facts=retrieved_facts if retrieved_facts else "None") + chain = Dialectic( + agent_input=query, + retrieved_facts=retrieved_facts if retrieved_facts else "None", + ) return chain @@ -65,9 +65,7 @@ async def chat( query: str, db: AsyncSession, ): - chain = await prep_inference( - db, app_id, user_id, query - ) + chain = await prep_inference(db, app_id, user_id, query) response = await chain.call_async() return schemas.AgentChat(content=response.content) @@ -76,11 +74,8 @@ async def chat( async def stream( app_id: uuid.UUID, user_id: uuid.UUID, - session_id: uuid.UUID, query: str, db: AsyncSession, ): - chain = await prep_inference( - db, app_id, user_id, query - ) + chain = await prep_inference(db, app_id, user_id, query) return chain.stream_async() diff --git a/src/deriver.py b/src/deriver.py index aab911f..163ddae 100644 --- a/src/deriver.py +++ b/src/deriver.py @@ -1,18 +1,15 @@ import asyncio -import os, re -import time +import os +import re import uuid from typing import List import sentry_sdk from dotenv import load_dotenv - from mirascope.openai import OpenAICall, OpenAICallParams - from realtime.connection import Socket from sqlalchemy import select from sqlalchemy.orm import selectinload -from websockets.exceptions import ConnectionClosedError from . import crud, models, schemas from .db import SessionLocal @@ -30,6 +27,7 @@ SUPABASE_ID = os.getenv("SUPABASE_ID") SUPABASE_API_KEY = os.getenv("SUPABASE_API_KEY") + class DeriveFacts(OpenAICall): prompt_template = """ You are tasked with deriving discrete facts about the user based on their input. The goal is to only extract absolute facts from the message, do not make inferences beyond the text provided. @@ -45,13 +43,14 @@ class DeriveFacts(OpenAICall): call_params = OpenAICallParams(model="gpt-4o-2024-05-13") + class CheckDups(OpenAICall): prompt_template = """ Your job is to determine if the new fact exists in the old: Old: ```{existing_facts}``` - New: ```{facts}``` + New: ```{fact}``` If the new fact is sufficiently represented in the old list, return False. Otherwise, if the fact is indeed new, return True. """ @@ -128,8 +127,12 @@ async def process_user_message( # contents = [m.content for m in messages] # print(contents) - chat_history_str = "\n".join([f"user: {m.content}" if m.is_user else f"ai: {m.content}" for m in messages]) - facts_response = await DeriveFacts(chat_history=chat_history_str, user_input=content).call_async() + chat_history_str = "\n".join( + [f"user: {m.content}" if m.is_user else f"ai: {m.content}" for m in messages] + ) + facts_response = await DeriveFacts( + chat_history=chat_history_str, user_input=content + ).call_async() facts = re.findall(r"\d+\.\s([^\n]+)", facts_response.content) print("===================") @@ -156,7 +159,6 @@ async def process_user_message( # print(f"Created fact: {fact}") - async def check_dups( app_id: uuid.UUID, user_id: uuid.UUID, collection_id: uuid.UUID, facts: List[str] ): @@ -181,7 +183,7 @@ async def check_dups( new_facts.append(fact) print(f"New Fact: {fact}") continue - + global_existing_facts.extend(existing_facts) # for debugging check_duplication.existing_facts = existing_facts @@ -192,7 +194,6 @@ async def check_dups( print(f"New Fact: {fact}") continue - print("===================") print(f"Existing Facts: {global_existing_facts}") print(f"Net New Facts {new_facts}") @@ -200,10 +201,9 @@ async def check_dups( return new_facts - if __name__ == "__main__": - URL = f"wss://{SUPABASE_ID}.supabase.co/realtime/v1/websocket?apikey={SUPABASE_API_KEY}&vsn=1.0.0" - # URL = f"ws://127.0.0.1:54321/realtime/v1/websocket?apikey={SUPABASE_API_KEY}" # For local Supabase + # URL = f"wss://{SUPABASE_ID}.supabase.co/realtime/v1/websocket?apikey={SUPABASE_API_KEY}&vsn=1.0.0" + URL = f"ws://127.0.0.1:54321/realtime/v1/websocket?apikey={SUPABASE_API_KEY}" # For local Supabase # listen_to_websocket(URL) s = Socket(URL) s.connect() diff --git a/src/routers/sessions.py b/src/routers/sessions.py index 5b83139..1a4bf68 100644 --- a/src/routers/sessions.py +++ b/src/routers/sessions.py @@ -196,9 +196,7 @@ async def get_chat( db=db, auth=Depends(auth), ): - return await agent.chat( - app_id=app_id, user_id=user_id, session_id=session_id, query=query, db=db - ) + return await agent.chat(app_id=app_id, user_id=user_id, query=query, db=db) @router.get( @@ -221,14 +219,11 @@ async def get_chat_stream( db=db, auth=Depends(auth), ): + async def parse_stream(): + stream = await agent.stream(app_id=app_id, user_id=user_id, query=query, db=db) + async for chunk in stream: + yield chunk.content + return StreamingResponse( - await agent.stream( - app_id=app_id, - user_id=user_id, - session_id=session_id, - query=query, - db=db, - ), - media_type="text/event-stream", - status_code=200, + content=parse_stream(), media_type="text/event-stream", status_code=200 )