From f79d98ed39bcb90da7f9138d003d5899293399ec Mon Sep 17 00:00:00 2001 From: Vineeth Voruganti <13438633+VVoruganti@users.noreply.github.com> Date: Wed, 16 Oct 2024 11:52:22 -0400 Subject: [PATCH 1/5] fix(embedding) use base openai for embedding and remove azure --- .env.template | 8 -------- src/agent.py | 23 +++++------------------ src/crud.py | 17 ++++++++--------- src/deriver/voe.py | 19 +++---------------- 4 files changed, 16 insertions(+), 51 deletions(-) diff --git a/.env.template b/.env.template index f0878f1..31fe4be 100644 --- a/.env.template +++ b/.env.template @@ -3,14 +3,6 @@ CONNECTION_URI=postgresql+psycopg://testuser:testpwd@localhost:5432/honcho # sam # CONNECTION_URI=postgresql+psycopg://testuser:testpwd@database:5432/honcho # sample for docker-compose database OPENAI_API_KEY= -ANTHROPIC_API_KEY= - -# Azure - -AZURE_OPENAI_ENDPOINT= -AZURE_OPENAI_API_KEY= -AZURE_OPENAI_API_VERSION= -AZURE_OPENAI_DEPLOYMENT= # Logging diff --git a/src/agent.py b/src/agent.py index 528634b..c788847 100644 --- a/src/agent.py +++ b/src/agent.py @@ -1,11 +1,9 @@ import asyncio -import os import uuid -from typing import Iterable, Set +from collections.abc import Iterable from dotenv import load_dotenv -from mirascope.base import BaseConfig -from mirascope.openai import OpenAICall, OpenAICallParams, azure_client_wrapper +from mirascope.openai import OpenAICall, OpenAICallParams from src import crud, schemas from src.db import SessionLocal @@ -15,7 +13,7 @@ class AsyncSet: def __init__(self): - self._set: Set[str] = set() + self._set: set[str] = set() self._lock = asyncio.Lock() async def add(self, item: str): @@ -26,7 +24,7 @@ async def update(self, items: Iterable[str]): async with self._lock: self._set.update(items) - def get_set(self) -> Set[str]: + def get_set(self) -> set[str]: return self._set.copy() @@ -44,18 +42,7 @@ class Dialectic(OpenAICall): retrieved_facts: str chat_history: list[str] - configuration = BaseConfig( - client_wrappers=[ - azure_client_wrapper( - api_key=os.getenv("AZURE_OPENAI_API_KEY"), - api_version=os.getenv("AZURE_OPENAI_API_VERSION"), - azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"), - ) - ] - ) - call_params = OpenAICallParams( - model=os.getenv("AZURE_OPENAI_DEPLOYMENT"), temperature=1.2, top_p=0.5 - ) + call_params = OpenAICallParams(model="gpt-4o", temperature=1.2, top_p=0.5) async def chat_history( diff --git a/src/crud.py b/src/crud.py index c0ce80b..2e9bc8b 100644 --- a/src/crud.py +++ b/src/crud.py @@ -3,7 +3,8 @@ import uuid from typing import Optional, Sequence -from openai import AzureOpenAI, OpenAI +from dotenv import load_dotenv +from openai import OpenAI from sqlalchemy import Select, select from sqlalchemy.exc import IntegrityError from sqlalchemy.ext.asyncio import AsyncSession @@ -11,11 +12,9 @@ # from sqlalchemy.orm import Session from . import models, schemas -openai_client = AzureOpenAI( - api_key=os.getenv("AZURE_OPENAI_API_KEY"), - api_version=os.getenv("AZURE_OPENAI_API_VERSION"), - azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"), -) +load_dotenv(override=True) + +openai_client = OpenAI() ######################################################## # app methods @@ -711,7 +710,7 @@ async def query_documents( top_k: int = 5, ) -> Sequence[models.Document]: response = openai_client.embeddings.create( - input=query, model=os.getenv("AZURE_OPENAI_EMBED_DEPLOYMENT") + model="text-embedding-3-small", input=query ) embedding_query = response.data[0].embedding stmt = ( @@ -747,7 +746,7 @@ async def create_document( raise ValueError("Session not found or does not belong to user") response = openai_client.embeddings.create( - input=document.content, model=os.getenv("AZURE_OPENAI_EMBED_DEPLOYMENT") + input=document.content, model="text-embedding-3-small" ) embedding = response.data[0].embedding @@ -784,7 +783,7 @@ async def update_document( if document.content is not None: honcho_document.content = document.content response = openai_client.embeddings.create( - input=document.content, model=os.getenv("AZURE_OPENAI_EMBED_DEPLOYMENT") + input=document.content, model="text-embedding-3-small" ) embedding = response.data[0].embedding honcho_document.embedding = embedding diff --git a/src/deriver/voe.py b/src/deriver/voe.py index 8c7b6b9..4d7b9be 100644 --- a/src/deriver/voe.py +++ b/src/deriver/voe.py @@ -1,24 +1,11 @@ -import os -from typing import List - -from mirascope.base import BaseConfig -from mirascope.openai import OpenAICall, OpenAICallParams, azure_client_wrapper +from mirascope.openai import OpenAICall, OpenAICallParams from pydantic import ConfigDict class HonchoCall(OpenAICall): model_config = ConfigDict(arbitrary_types_allowed=True) - configuration = BaseConfig( - client_wrappers=[ - azure_client_wrapper( - api_key=os.getenv("AZURE_OPENAI_API_KEY"), - api_version=os.getenv("AZURE_OPENAI_API_VERSION"), - azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"), # type: ignore - ) - ] - ) call_params = OpenAICallParams( - model=os.getenv("AZURE_OPENAI_DEPLOYMENT"), # type: ignore + model="gpt-4o", # type: ignore temperature=1.2, top_p=0.5, ) @@ -143,5 +130,5 @@ class CheckVoeList(HonchoCall): If you believe the new fact is sufficiently new given the ones in the list, output true. If not, output false. Do not provide extra commentary, only output a boolean value. ''' - existing_facts: List[str] + existing_facts: list[str] new_fact: str From fe49dfc8b8d54444184c7dbf7e291922c3088ec4 Mon Sep 17 00:00:00 2001 From: Vineeth Voruganti <13438633+VVoruganti@users.noreply.github.com> Date: Wed, 16 Oct 2024 15:43:54 -0400 Subject: [PATCH 2/5] feat(models) switch to nanoids with internal and public id system --- .github/workflows/unittest.yml | 5 - pyproject.toml | 1 + src/agent.py | 23 ++- src/crud.py | 289 +++++++++++++++++---------------- src/deriver/consumer.py | 26 ++- src/models.py | 119 ++++++++------ src/routers/apps.py | 16 +- src/routers/collections.py | 35 ++-- src/routers/documents.py | 43 +++-- src/routers/messages.py | 35 ++-- src/routers/metamessages.py | 39 +++-- src/routers/sessions.py | 65 ++++---- src/routers/users.py | 29 ++-- src/schemas.py | 95 ++++------- tests/conftest.py | 6 +- tests/routes/test_apps.py | 10 +- tests/routes/test_users.py | 8 +- uv.lock | 11 ++ 18 files changed, 421 insertions(+), 434 deletions(-) diff --git a/.github/workflows/unittest.yml b/.github/workflows/unittest.yml index 3586002..6c47961 100644 --- a/.github/workflows/unittest.yml +++ b/.github/workflows/unittest.yml @@ -51,11 +51,6 @@ jobs: SENTRY_ENABLED: false OPENTELEMETRY_ENABLED: false OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }} - AZURE_OPENAI_ENDPOINT: ${{ secrets.AZURE_OPENAI_ENDPOINT }} - AZURE_OPENAI_API_KEY: ${{ secrets.AZURE_OPENAI_API_KEY }} - AZURE_OPENAI_API_VERSION: ${{ secrets.AZURE_OPENAI_API_VERSION }} - AZURE_OPENAI_DEPLOYMENT: ${{ secrets.AZURE_OPENAI_DEPLOYMENT }} - AZURE_OPENAI_EMBED_DEPLOYMENT: ${{ secrets.AZURE_OPENAI_EMBED_DEPLOYMENT }} diff --git a/pyproject.toml b/pyproject.toml index e6f1893..bda1cc6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,6 +25,7 @@ dependencies = [ "rich>=13.7.1", "mirascope>=0.18.0", "openai>=1.43.0", + "nanoid>=2.0.0", ] [tool.uv] dev-dependencies = [ diff --git a/src/agent.py b/src/agent.py index c788847..70ebcc5 100644 --- a/src/agent.py +++ b/src/agent.py @@ -1,5 +1,4 @@ import asyncio -import uuid from collections.abc import Iterable from dotenv import load_dotenv @@ -45,9 +44,7 @@ class Dialectic(OpenAICall): call_params = OpenAICallParams(model="gpt-4o", temperature=1.2, top_p=0.5) -async def chat_history( - app_id: uuid.UUID, user_id: uuid.UUID, session_id: uuid.UUID -) -> list[str]: +async def chat_history(app_id: str, user_id: str, session_id: str) -> list[str]: async with SessionLocal() as db: stmt = await crud.get_messages(db, app_id, user_id, session_id) results = await db.execute(stmt) @@ -62,8 +59,8 @@ async def chat_history( async def prep_inference( - app_id: uuid.UUID, - user_id: uuid.UUID, + app_id: str, + user_id: str, query: str, collection_name: str, ) -> None | list[str]: @@ -88,8 +85,8 @@ async def prep_inference( async def generate_facts( - app_id: uuid.UUID, - user_id: uuid.UUID, + app_id: str, + user_id: str, fact_set: AsyncSet, collection_name: str, questions: list[str], @@ -103,8 +100,8 @@ async def fetch_facts(query): async def fact_generator( - app_id: uuid.UUID, - user_id: uuid.UUID, + app_id: str, + user_id: str, collections: list[str], questions: list[str], ): @@ -121,9 +118,9 @@ async def fact_generator( async def chat( - app_id: uuid.UUID, - user_id: uuid.UUID, - session_id: uuid.UUID, + app_id: str, + user_id: str, + session_id: str, query: schemas.AgentQuery, stream: bool = False, ): diff --git a/src/crud.py b/src/crud.py index 2e9bc8b..9d2ca80 100644 --- a/src/crud.py +++ b/src/crud.py @@ -1,7 +1,6 @@ import datetime -import os -import uuid -from typing import Optional, Sequence +from collections.abc import Sequence +from typing import Optional from dotenv import load_dotenv from openai import OpenAI @@ -21,8 +20,8 @@ ######################################################## -async def get_app(db: AsyncSession, app_id: uuid.UUID) -> Optional[models.App]: - stmt = select(models.App).where(models.App.id == app_id) +async def get_app(db: AsyncSession, app_id: str) -> Optional[models.App]: + stmt = select(models.App).where(models.App.public_id == app_id) result = await db.execute(stmt) app = result.scalar_one_or_none() return app @@ -48,7 +47,7 @@ async def create_app(db: AsyncSession, app: schemas.AppCreate) -> models.App: async def update_app( - db: AsyncSession, app_id: uuid.UUID, app: schemas.AppUpdate + db: AsyncSession, app_id: str, app: schemas.AppUpdate ) -> models.App: honcho_app = await get_app(db, app_id) if honcho_app is None: @@ -63,7 +62,7 @@ async def update_app( return honcho_app -# def delete_app(db: AsyncSession, app_id: uuid.UUID) -> bool: +# def delete_app(db: AsyncSession, app_id: str) -> bool: # existing_app = get_app(db, app_id) # if existing_app is None: # return False @@ -78,7 +77,7 @@ async def update_app( async def create_user( - db: AsyncSession, app_id: uuid.UUID, user: schemas.UserCreate + db: AsyncSession, app_id: str, user: schemas.UserCreate ) -> models.User: honcho_user = models.User( app_id=app_id, @@ -92,12 +91,12 @@ async def create_user( async def get_user( - db: AsyncSession, app_id: uuid.UUID, user_id: uuid.UUID + db: AsyncSession, app_id: str, user_id: str ) -> Optional[models.User]: stmt = ( select(models.User) .where(models.User.app_id == app_id) - .where(models.User.id == user_id) + .where(models.User.public_id == user_id) ) result = await db.execute(stmt) user = result.scalar_one_or_none() @@ -105,7 +104,7 @@ async def get_user( async def get_user_by_name( - db: AsyncSession, app_id: uuid.UUID, name: str + db: AsyncSession, app_id: str, name: str ) -> Optional[models.User]: stmt = ( select(models.User) @@ -119,7 +118,7 @@ async def get_user_by_name( async def get_users( db: AsyncSession, - app_id: uuid.UUID, + app_id: str, reverse: bool = False, filter: Optional[dict] = None, ) -> Select: @@ -137,7 +136,7 @@ async def get_users( async def update_user( - db: AsyncSession, app_id: uuid.UUID, user_id: uuid.UUID, user: schemas.UserUpdate + db: AsyncSession, app_id: str, user_id: str, user: schemas.UserUpdate ) -> models.User: honcho_user = await get_user(db, app_id, user_id) if honcho_user is None: @@ -152,7 +151,7 @@ async def update_user( return honcho_user -# def delete_user(db: AsyncSession, app_id: uuid.UUID, user_id: uuid.UUID) -> bool: +# def delete_user(db: AsyncSession, app_id: str, user_id: str) -> bool: # existing_user = get_user(db, app_id, user_id) # if existing_user is None: # return False @@ -167,15 +166,15 @@ async def update_user( async def get_session( db: AsyncSession, - app_id: uuid.UUID, - session_id: uuid.UUID, - user_id: Optional[uuid.UUID] = None, + app_id: str, + session_id: str, + user_id: Optional[str] = None, ) -> Optional[models.Session]: stmt = ( select(models.Session) - .join(models.User, models.User.id == models.Session.user_id) + .join(models.User, models.User.public_id == models.Session.user_id) .where(models.User.app_id == app_id) - .where(models.Session.id == session_id) + .where(models.Session.public_id == session_id) ) if user_id is not None: stmt = stmt.where(models.Session.user_id == user_id) @@ -186,15 +185,15 @@ async def get_session( async def get_sessions( db: AsyncSession, - app_id: uuid.UUID, - user_id: uuid.UUID, + app_id: str, + user_id: str, reverse: Optional[bool] = False, is_active: Optional[bool] = False, filter: Optional[dict] = None, ) -> Select: stmt = ( select(models.Session) - .join(models.User, models.User.id == models.Session.user_id) + .join(models.User, models.User.public_id == models.Session.user_id) .where(models.User.app_id == app_id) .where(models.Session.user_id == user_id) ) @@ -216,8 +215,8 @@ async def get_sessions( async def create_session( db: AsyncSession, session: schemas.SessionCreate, - app_id: uuid.UUID, - user_id: uuid.UUID, + app_id: str, + user_id: str, ) -> models.Session: honcho_user = await get_user(db, app_id=app_id, user_id=user_id) if honcho_user is None: @@ -245,9 +244,9 @@ async def create_session( async def update_session( db: AsyncSession, session: schemas.SessionUpdate, - app_id: uuid.UUID, - user_id: uuid.UUID, - session_id: uuid.UUID, + app_id: str, + user_id: str, + session_id: str, ) -> bool: honcho_session = await get_session( db, app_id=app_id, session_id=session_id, user_id=user_id @@ -264,12 +263,12 @@ async def update_session( async def delete_session( - db: AsyncSession, app_id: uuid.UUID, user_id: uuid.UUID, session_id: uuid.UUID + db: AsyncSession, app_id: str, user_id: str, session_id: str ) -> bool: stmt = ( select(models.Session) - .join(models.User, models.User.id == models.Session.user_id) - .where(models.Session.id == session_id) + .join(models.User, models.User.public_id == models.Session.user_id) + .where(models.Session.public_id == session_id) .where(models.User.app_id == app_id) .where(models.Session.user_id == user_id) ) @@ -290,9 +289,9 @@ async def delete_session( async def create_message( db: AsyncSession, message: schemas.MessageCreate, - app_id: uuid.UUID, - user_id: uuid.UUID, - session_id: uuid.UUID, + app_id: str, + user_id: str, + session_id: str, ) -> models.Message: honcho_session = await get_session( db, app_id=app_id, session_id=session_id, user_id=user_id @@ -315,19 +314,19 @@ async def create_message( async def get_messages( db: AsyncSession, - app_id: uuid.UUID, - user_id: uuid.UUID, - session_id: uuid.UUID, + app_id: str, + user_id: str, + session_id: str, reverse: Optional[bool] = False, filter: Optional[dict] = None, ) -> Select: stmt = ( select(models.Message) - .join(models.Session, models.Session.id == models.Message.session_id) - .join(models.User, models.User.id == models.Session.user_id) - .join(models.App, models.App.id == models.User.app_id) - .where(models.App.id == app_id) - .where(models.User.id == user_id) + .join(models.Session, models.Session.public_id == models.Message.session_id) + .join(models.User, models.User.public_id == models.Session.user_id) + .join(models.App, models.App.public_id == models.User.app_id) + .where(models.App.public_id == app_id) + .where(models.User.public_id == user_id) .where(models.Message.session_id == session_id) ) @@ -344,20 +343,20 @@ async def get_messages( async def get_message( db: AsyncSession, - app_id: uuid.UUID, - user_id: uuid.UUID, - session_id: uuid.UUID, - message_id: uuid.UUID, + app_id: str, + user_id: str, + session_id: str, + message_id: str, ) -> Optional[models.Message]: stmt = ( select(models.Message) - .join(models.Session, models.Session.id == models.Message.session_id) - .join(models.User, models.User.id == models.Session.user_id) - .join(models.App, models.App.id == models.User.app_id) - .where(models.App.id == app_id) - .where(models.User.id == user_id) + .join(models.Session, models.Session.public_id == models.Message.session_id) + .join(models.User, models.User.public_id == models.Session.user_id) + .join(models.App, models.App.public_id == models.User.app_id) + .where(models.App.public_id == app_id) + .where(models.User.public_id == user_id) .where(models.Message.session_id == session_id) - .where(models.Message.id == message_id) + .where(models.Message.public_id == message_id) ) result = await db.execute(stmt) return result.scalar_one_or_none() @@ -366,10 +365,10 @@ async def get_message( async def update_message( db: AsyncSession, message: schemas.MessageUpdate, - app_id: uuid.UUID, - user_id: uuid.UUID, - session_id: uuid.UUID, - message_id: uuid.UUID, + app_id: str, + user_id: str, + session_id: str, + message_id: str, ) -> bool: honcho_message = await get_message( db, app_id=app_id, session_id=session_id, user_id=user_id, message_id=message_id @@ -393,9 +392,9 @@ async def update_message( async def create_metamessage( db: AsyncSession, metamessage: schemas.MetamessageCreate, - app_id: uuid.UUID, - user_id: uuid.UUID, - session_id: uuid.UUID, + app_id: str, + user_id: str, + session_id: str, ): message = await get_message( db, @@ -422,22 +421,22 @@ async def create_metamessage( async def get_metamessages( db: AsyncSession, - app_id: uuid.UUID, - user_id: uuid.UUID, - session_id: uuid.UUID, - message_id: Optional[uuid.UUID], + app_id: str, + user_id: str, + session_id: str, + message_id: Optional[str], metamessage_type: Optional[str] = None, filter: Optional[dict] = None, reverse: Optional[bool] = False, ) -> Select: stmt = ( select(models.Metamessage) - .join(models.Message, models.Message.id == models.Metamessage.message_id) - .join(models.Session, models.Message.session_id == models.Session.id) - .join(models.User, models.User.id == models.Session.user_id) - .join(models.App, models.App.id == models.User.app_id) - .where(models.App.id == app_id) - .where(models.User.id == user_id) + .join(models.Message, models.Message.public_id == models.Metamessage.message_id) + .join(models.Session, models.Message.session_id == models.Session.public_id) + .join(models.User, models.User.public_id == models.Session.user_id) + .join(models.App, models.App.public_id == models.User.app_id) + .where(models.App.public_id == app_id) + .where(models.User.public_id == user_id) .where(models.Message.session_id == session_id) ) @@ -460,23 +459,23 @@ async def get_metamessages( async def get_metamessage( db: AsyncSession, - app_id: uuid.UUID, - user_id: uuid.UUID, - session_id: uuid.UUID, - message_id: uuid.UUID, - metamessage_id: uuid.UUID, + app_id: str, + user_id: str, + session_id: str, + message_id: str, + metamessage_id: str, ) -> Optional[models.Metamessage]: stmt = ( select(models.Metamessage) - .join(models.Message, models.Message.id == models.Metamessage.message_id) - .join(models.Session, models.Message.session_id == models.Session.id) - .join(models.User, models.User.id == models.Session.user_id) - .join(models.App, models.App.id == models.User.app_id) - .where(models.App.id == app_id) - .where(models.User.id == user_id) + .join(models.Message, models.Message.public_id == models.Metamessage.message_id) + .join(models.Session, models.Message.session_id == models.Session.public_id) + .join(models.User, models.User.public_id == models.Session.user_id) + .join(models.App, models.App.public_id == models.User.app_id) + .where(models.App.public_id == app_id) + .where(models.User.public_id == user_id) .where(models.Message.session_id == session_id) .where(models.Metamessage.message_id == message_id) - .where(models.Metamessage.id == metamessage_id) + .where(models.Metamessage.public_id == metamessage_id) ) result = await db.execute(stmt) return result.scalar_one_or_none() @@ -485,10 +484,10 @@ async def get_metamessage( async def update_metamessage( db: AsyncSession, metamessage: schemas.MetamessageUpdate, - app_id: uuid.UUID, - user_id: uuid.UUID, - session_id: uuid.UUID, - metamessage_id: uuid.UUID, + app_id: str, + user_id: str, + session_id: str, + metamessage_id: str, ) -> bool: honcho_metamessage = await get_metamessage( db, @@ -520,17 +519,17 @@ async def update_metamessage( async def get_collections( db: AsyncSession, - app_id: uuid.UUID, - user_id: uuid.UUID, + app_id: str, + user_id: str, reverse: Optional[bool] = False, filter: Optional[dict] = None, ) -> Select: """Get a distinct list of the names of collections associated with a user""" stmt = ( select(models.Collection) - .join(models.User, models.User.id == models.Collection.user_id) + .join(models.User, models.User.public_id == models.Collection.user_id) .where(models.User.app_id == app_id) - .where(models.User.id == user_id) + .where(models.User.public_id == user_id) ) if filter is not None: @@ -545,14 +544,14 @@ async def get_collections( async def get_collection_by_id( - db: AsyncSession, app_id: uuid.UUID, user_id: uuid.UUID, collection_id: uuid.UUID + db: AsyncSession, app_id: str, user_id: str, collection_id: str ) -> Optional[models.Collection]: stmt = ( select(models.Collection) - .join(models.User, models.User.id == models.Collection.user_id) + .join(models.User, models.User.public_id == models.Collection.user_id) .where(models.User.app_id == app_id) - .where(models.User.id == user_id) - .where(models.Collection.id == collection_id) + .where(models.User.public_id == user_id) + .where(models.Collection.public_id == collection_id) ) result = await db.execute(stmt) collection = result.scalar_one_or_none() @@ -560,13 +559,13 @@ async def get_collection_by_id( async def get_collection_by_name( - db: AsyncSession, app_id: uuid.UUID, user_id: uuid.UUID, name: str + db: AsyncSession, app_id: str, user_id: str, name: str ) -> Optional[models.Collection]: stmt = ( select(models.Collection) - .join(models.User, models.User.id == models.Collection.user_id) + .join(models.User, models.User.public_id == models.Collection.user_id) .where(models.User.app_id == app_id) - .where(models.User.id == user_id) + .where(models.User.public_id == user_id) .where(models.Collection.name == name) ) result = await db.execute(stmt) @@ -577,8 +576,8 @@ async def get_collection_by_name( async def create_collection( db: AsyncSession, collection: schemas.CollectionCreate, - app_id: uuid.UUID, - user_id: uuid.UUID, + app_id: str, + user_id: str, ) -> models.Collection: honcho_collection = models.Collection( user_id=user_id, @@ -598,9 +597,9 @@ async def create_collection( async def update_collection( db: AsyncSession, collection: schemas.CollectionUpdate, - app_id: uuid.UUID, - user_id: uuid.UUID, - collection_id: uuid.UUID, + app_id: str, + user_id: str, + collection_id: str, ) -> models.Collection: honcho_collection = await get_collection_by_id( db, app_id=app_id, user_id=user_id, collection_id=collection_id @@ -621,7 +620,7 @@ async def update_collection( async def delete_collection( - db: AsyncSession, app_id: uuid.UUID, user_id: uuid.UUID, collection_id: uuid.UUID + db: AsyncSession, app_id: str, user_id: str, collection_id: str ) -> bool: """ Delete a Collection and all documents associated with it. Takes advantage of @@ -629,10 +628,10 @@ async def delete_collection( """ stmt = ( select(models.Collection) - .join(models.User, models.User.id == models.Collection.user_id) + .join(models.User, models.User.public_id == models.Collection.user_id) .where(models.User.app_id == app_id) - .where(models.User.id == user_id) - .where(models.Collection.id == collection_id) + .where(models.User.public_id == user_id) + .where(models.Collection.public_id == collection_id) ) result = await db.execute(stmt) honcho_collection = result.scalar_one_or_none() @@ -652,18 +651,21 @@ async def delete_collection( async def get_documents( db: AsyncSession, - app_id: uuid.UUID, - user_id: uuid.UUID, - collection_id: uuid.UUID, + app_id: str, + user_id: str, + collection_id: str, reverse: Optional[bool] = False, filter: Optional[dict] = None, ) -> Select: stmt = ( select(models.Document) - .join(models.Collection, models.Collection.id == models.Document.collection_id) - .join(models.User, models.User.id == models.Collection.user_id) + .join( + models.Collection, + models.Collection.public_id == models.Document.collection_id, + ) + .join(models.User, models.User.public_id == models.Collection.user_id) .where(models.User.app_id == app_id) - .where(models.User.id == user_id) + .where(models.User.public_id == user_id) .where(models.Document.collection_id == collection_id) ) @@ -680,19 +682,22 @@ async def get_documents( async def get_document( db: AsyncSession, - app_id: uuid.UUID, - user_id: uuid.UUID, - collection_id: uuid.UUID, - document_id: uuid.UUID, + app_id: str, + user_id: str, + collection_id: str, + document_id: str, ) -> Optional[models.Document]: stmt = ( select(models.Document) - .join(models.Collection, models.Collection.id == models.Document.collection_id) - .join(models.User, models.User.id == models.Collection.user_id) + .join( + models.Collection, + models.Collection.public_id == models.Document.collection_id, + ) + .join(models.User, models.User.public_id == models.Collection.user_id) .where(models.User.app_id == app_id) - .where(models.User.id == user_id) + .where(models.User.public_id == user_id) .where(models.Document.collection_id == collection_id) - .where(models.Document.id == document_id) + .where(models.Document.public_id == document_id) ) result = await db.execute(stmt) @@ -702,9 +707,9 @@ async def get_document( async def query_documents( db: AsyncSession, - app_id: uuid.UUID, - user_id: uuid.UUID, - collection_id: uuid.UUID, + app_id: str, + user_id: str, + collection_id: str, query: str, filter: Optional[dict] = None, top_k: int = 5, @@ -715,10 +720,13 @@ async def query_documents( embedding_query = response.data[0].embedding stmt = ( select(models.Document) - .join(models.Collection, models.Collection.id == models.Document.collection_id) - .join(models.User, models.User.id == models.Collection.user_id) + .join( + models.Collection, + models.Collection.public_id == models.Document.collection_id, + ) + .join(models.User, models.User.public_id == models.Collection.user_id) .where(models.User.app_id == app_id) - .where(models.User.id == user_id) + .where(models.User.public_id == user_id) .where(models.Document.collection_id == collection_id) # .limit(top_k) ) @@ -734,9 +742,9 @@ async def query_documents( async def create_document( db: AsyncSession, document: schemas.DocumentCreate, - app_id: uuid.UUID, - user_id: uuid.UUID, - collection_id: uuid.UUID, + app_id: str, + user_id: str, + collection_id: str, ) -> models.Document: """Embed a message as a vector and create a document""" collection = await get_collection_by_id( @@ -766,10 +774,10 @@ async def create_document( async def update_document( db: AsyncSession, document: schemas.DocumentUpdate, - app_id: uuid.UUID, - user_id: uuid.UUID, - collection_id: uuid.UUID, - document_id: uuid.UUID, + app_id: str, + user_id: str, + collection_id: str, + document_id: str, ) -> bool: honcho_document = await get_document( db, @@ -798,19 +806,22 @@ async def update_document( async def delete_document( db: AsyncSession, - app_id: uuid.UUID, - user_id: uuid.UUID, - collection_id: uuid.UUID, - document_id: uuid.UUID, + app_id: str, + user_id: str, + collection_id: str, + document_id: str, ) -> bool: stmt = ( select(models.Document) - .join(models.Collection, models.Collection.id == models.Document.collection_id) - .join(models.User, models.User.id == models.Collection.user_id) + .join( + models.Collection, + models.Collection.public_id == models.Document.collection_id, + ) + .join(models.User, models.User.public_id == models.Collection.user_id) .where(models.User.app_id == app_id) - .where(models.User.id == user_id) + .where(models.User.public_id == user_id) .where(models.Document.collection_id == collection_id) - .where(models.Document.id == document_id) + .where(models.Document.public_id == document_id) ) result = await db.execute(stmt) document = result.scalar_one_or_none() diff --git a/src/deriver/consumer.py b/src/deriver/consumer.py index 8aeca7c..afde8b8 100644 --- a/src/deriver/consumer.py +++ b/src/deriver/consumer.py @@ -1,7 +1,5 @@ import logging import re -import uuid -from typing import List from dotenv import load_dotenv from rich import print as rprint @@ -67,11 +65,11 @@ async def process_item(db: AsyncSession, payload: dict): async def process_ai_message( content: str, - app_id: uuid.UUID, - user_id: uuid.UUID, - session_id: uuid.UUID, - collection_id: uuid.UUID, - message_id: uuid.UUID, + app_id: str, + user_id: str, + session_id: str, + collection_id: str, + message_id: str, db: AsyncSession, ): """ @@ -192,11 +190,11 @@ async def process_ai_message( async def process_user_message( content: str, - app_id: uuid.UUID, - user_id: uuid.UUID, - session_id: uuid.UUID, - collection_id: uuid.UUID, - message_id: uuid.UUID, + app_id: str, + user_id: str, + session_id: str, + collection_id: str, + message_id: str, db: AsyncSession, ): """ @@ -302,9 +300,7 @@ async def process_user_message( return -async def check_dups( - app_id: uuid.UUID, user_id: uuid.UUID, collection_id: uuid.UUID, facts: List[str] -): +async def check_dups(app_id: str, user_id: str, collection_id: str, facts: list[str]): """Check that we're not storing duplicate facts""" check_duplication = CheckVoeList(existing_facts=[], new_fact="") diff --git a/src/models.py b/src/models.py index 289f9cb..4fcd305 100644 --- a/src/models.py +++ b/src/models.py @@ -1,18 +1,16 @@ import datetime -import uuid from dotenv import load_dotenv +from nanoid import generate as generate_nanoid from pgvector.sqlalchemy import Vector from sqlalchemy import ( - JSON, + BigInteger, Boolean, - Column, DateTime, ForeignKey, - Integer, + Identity, String, UniqueConstraint, - Uuid, ) from sqlalchemy.dialects.postgresql import JSONB from sqlalchemy.orm import Mapped, mapped_column, relationship @@ -22,36 +20,37 @@ load_dotenv() -# DATABASE_TYPE = os.getenv("DATABASE_TYPE", "postgres") - -# ColumnType = JSONB if DATABASE_TYPE == "postgres" else JSON - class App(Base): __tablename__ = "apps" - id: Mapped[uuid.UUID] = mapped_column( - primary_key=True, index=True, default=uuid.uuid4 + id: Mapped[int] = mapped_column( + BigInteger, Identity(), primary_key=True, index=True, autoincrement=True + ) + public_id: Mapped[str] = mapped_column( + String(21), index=True, unique=True, default=generate_nanoid ) name: Mapped[str] = mapped_column(String(512), index=True, unique=True) users = relationship("User", back_populates="app") created_at: Mapped[datetime.datetime] = mapped_column( - DateTime(timezone=True), default=func.now() + DateTime(timezone=True), index=True, default=func.now() ) - h_metadata: Mapped[dict] = mapped_column("metadata", JSONB, default={}) - # Add any additional fields for an app here + h_metadata: Mapped[dict] = mapped_column("h_metadata", JSONB, default={}) class User(Base): __tablename__ = "users" - id: Mapped[uuid.UUID] = mapped_column( - primary_key=True, index=True, default=uuid.uuid4 + id: Mapped[int] = mapped_column( + BigInteger, Identity(), primary_key=True, index=True, autoincrement=True + ) + public_id: Mapped[str] = mapped_column( + String(21), index=True, unique=True, default=generate_nanoid ) name: Mapped[str] = mapped_column(String(512), index=True) - h_metadata: Mapped[dict] = mapped_column("metadata", JSONB, default={}) + h_metadata: Mapped[dict] = mapped_column("h_metadata", JSONB, default={}) created_at: Mapped[datetime.datetime] = mapped_column( - DateTime(timezone=True), default=func.now() + DateTime(timezone=True), index=True, default=func.now() ) - app_id: Mapped[uuid.UUID] = mapped_column(ForeignKey("apps.id"), index=True) + app_id: Mapped[str] = mapped_column(ForeignKey("apps.public_id"), index=True) app = relationship("App", back_populates="users") sessions = relationship("Session", back_populates="user") collections = relationship("Collection", back_populates="user") @@ -64,16 +63,19 @@ def __repr__(self) -> str: class Session(Base): __tablename__ = "sessions" - id: Mapped[uuid.UUID] = mapped_column( - primary_key=True, index=True, default=uuid.uuid4 + id: Mapped[int] = mapped_column( + BigInteger, Identity(), primary_key=True, index=True, autoincrement=True + ) + public_id: Mapped[str] = mapped_column( + String(21), index=True, unique=True, default=generate_nanoid ) is_active: Mapped[bool] = mapped_column(default=True) - h_metadata: Mapped[dict] = mapped_column("metadata", JSONB, default={}) + h_metadata: Mapped[dict] = mapped_column("h_metadata", JSONB, default={}) created_at: Mapped[datetime.datetime] = mapped_column( - DateTime(timezone=True), default=func.now() + DateTime(timezone=True), index=True, default=func.now() ) messages = relationship("Message", back_populates="session") - user_id: Mapped[uuid.UUID] = mapped_column(ForeignKey("users.id"), index=True) + user_id: Mapped[str] = mapped_column(ForeignKey("users.public_id"), index=True) user = relationship("User", back_populates="sessions") def __repr__(self) -> str: @@ -82,16 +84,21 @@ def __repr__(self) -> str: class Message(Base): __tablename__ = "messages" - id: Mapped[uuid.UUID] = mapped_column( - primary_key=True, index=True, default=uuid.uuid4 + id: Mapped[int] = mapped_column( + BigInteger, Identity(), primary_key=True, index=True, autoincrement=True + ) + public_id: Mapped[str] = mapped_column( + String(21), index=True, unique=True, default=generate_nanoid + ) + session_id: Mapped[str] = mapped_column( + ForeignKey("sessions.public_id"), index=True ) - session_id: Mapped[uuid.UUID] = mapped_column(ForeignKey("sessions.id"), index=True) is_user: Mapped[bool] content: Mapped[str] = mapped_column(String(65535)) - h_metadata: Mapped[dict] = mapped_column("metadata", JSONB, default={}) + h_metadata: Mapped[dict] = mapped_column("h_metadata", JSONB, default={}) created_at: Mapped[datetime.datetime] = mapped_column( - DateTime(timezone=True), default=func.now() + DateTime(timezone=True), index=True, default=func.now() ) session = relationship("Session", back_populates="messages") metamessages = relationship("Metamessage", back_populates="message") @@ -102,18 +109,23 @@ def __repr__(self) -> str: class Metamessage(Base): __tablename__ = "metamessages" - id: Mapped[uuid.UUID] = mapped_column( - primary_key=True, index=True, default=uuid.uuid4 + id: Mapped[int] = mapped_column( + BigInteger, Identity(), primary_key=True, index=True, autoincrement=True + ) + public_id: Mapped[str] = mapped_column( + String(21), index=True, unique=True, default=generate_nanoid ) metamessage_type: Mapped[str] = mapped_column(String(512), index=True) content: Mapped[str] = mapped_column(String(65535)) - message_id = Column(Uuid, ForeignKey("messages.id"), index=True) + message_id: Mapped[str] = mapped_column( + ForeignKey("messages.public_id"), index=True + ) message = relationship("Message", back_populates="metamessages") created_at: Mapped[datetime.datetime] = mapped_column( - DateTime(timezone=True), default=func.now() + DateTime(timezone=True), index=True, default=func.now() ) - h_metadata: Mapped[dict] = mapped_column("metadata", JSONB, default={}) + h_metadata: Mapped[dict] = mapped_column("h_metadata", JSONB, default={}) def __repr__(self) -> str: return f"Metamessages(id={self.id}, message_id={self.message_id}, metamessage_type={self.metamessage_type}, content={self.content[10:]})" @@ -121,19 +133,25 @@ def __repr__(self) -> str: class Collection(Base): __tablename__ = "collections" - id: Mapped[uuid.UUID] = mapped_column( - primary_key=True, index=True, default=uuid.uuid4 + + id: Mapped[int] = mapped_column( + BigInteger, Identity(), primary_key=True, index=True, autoincrement=True + ) + public_id: Mapped[str] = mapped_column( + String(21), index=True, unique=True, default=generate_nanoid ) name: Mapped[str] = mapped_column(String(512), index=True) created_at: Mapped[datetime.datetime] = mapped_column( - DateTime(timezone=True), default=func.now() + DateTime(timezone=True), index=True, default=func.now() ) - h_metadata: Mapped[dict] = mapped_column("metadata", JSONB, default={}) + h_metadata: Mapped[dict] = mapped_column("h_metadata", JSONB, default={}) documents = relationship( "Document", back_populates="collection", cascade="all, delete, delete-orphan" ) user = relationship("User", back_populates="collections") - user_id: Mapped[uuid.UUID] = mapped_column(ForeignKey("users.id"), index=True) + user_id: Mapped[str] = mapped_column( + String(21), ForeignKey("users.public_id"), index=True + ) __table_args__ = ( UniqueConstraint("name", "user_id", name="unique_name_collection_user"), @@ -142,24 +160,31 @@ class Collection(Base): class Document(Base): __tablename__ = "documents" - id: Mapped[uuid.UUID] = mapped_column( - primary_key=True, index=True, default=uuid.uuid4 + id: Mapped[int] = mapped_column( + BigInteger, Identity(), primary_key=True, index=True, autoincrement=True ) - h_metadata: Mapped[dict] = mapped_column("metadata", JSONB, default={}) + public_id: Mapped[str] = mapped_column( + String(21), index=True, unique=True, default=generate_nanoid + ) + h_metadata: Mapped[dict] = mapped_column("h_metadata", JSONB, default={}) content: Mapped[str] = mapped_column(String(65535)) embedding = mapped_column(Vector(1536)) created_at: Mapped[datetime.datetime] = mapped_column( - DateTime(timezone=True), default=func.now() + DateTime(timezone=True), index=True, default=func.now() ) - collection_id = Column(Uuid, ForeignKey("collections.id"), index=True) + collection_id: Mapped[str] = mapped_column( + String(21), ForeignKey("collections.public_id"), index=True + ) collection = relationship("Collection", back_populates="documents") class QueueItem(Base): __tablename__ = "queue" - id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True) - session_id: Mapped[uuid.UUID] = mapped_column(ForeignKey("sessions.id"), index=True) + id: Mapped[int] = mapped_column( + BigInteger, Identity(), primary_key=True, autoincrement=True + ) + session_id: Mapped[int] = mapped_column(ForeignKey("sessions.id"), index=True) payload: Mapped[dict] = mapped_column(JSONB, nullable=False) processed: Mapped[bool] = mapped_column(Boolean, default=False) @@ -167,7 +192,7 @@ class QueueItem(Base): class ActiveQueueSession(Base): __tablename__ = "active_queue_sessions" - session_id: Mapped[uuid.UUID] = mapped_column(primary_key=True, index=True) + session_id: Mapped[int] = mapped_column(BigInteger, primary_key=True, index=True) last_updated: Mapped[datetime.datetime] = mapped_column( DateTime(timezone=True), default=func.now(), onupdate=func.now() ) diff --git a/src/routers/apps.py b/src/routers/apps.py index 393b1cc..dd51331 100644 --- a/src/routers/apps.py +++ b/src/routers/apps.py @@ -1,13 +1,7 @@ -import os import traceback -import uuid -from typing import Optional -import httpx from fastapi import APIRouter, Depends, HTTPException, Request -from psycopg.errors import UniqueViolation from sqlalchemy.exc import IntegrityError -from sqlalchemy.ext.asyncio import AsyncSession from src import crud, schemas from src.dependencies import db @@ -20,13 +14,11 @@ @router.get("/{app_id}", response_model=schemas.App) -async def get_app( - request: Request, app_id: uuid.UUID, db=db, auth: dict = Depends(auth) -): +async def get_app(request: Request, app_id: str, db=db, auth: dict = Depends(auth)): """Get an App by ID Args: - app_id (uuid.UUID): The ID of the app + app_id (str): The ID of the app Returns: schemas.App: App object @@ -125,7 +117,7 @@ async def get_or_create_app(request: Request, name: str, db=db, auth=Depends(aut @router.put("/{app_id}", response_model=schemas.App) async def update_app( request: Request, - app_id: uuid.UUID, + app_id: str, app: schemas.AppUpdate, db=db, auth=Depends(auth), @@ -133,7 +125,7 @@ async def update_app( """Update an App Args: - app_id (uuid.UUID): The ID of the app to update + app_id (str): The ID of the app to update app (schemas.AppUpdate): The App object containing any new metadata Returns: diff --git a/src/routers/collections.py b/src/routers/collections.py index ee95763..7ec0ab3 100644 --- a/src/routers/collections.py +++ b/src/routers/collections.py @@ -1,5 +1,4 @@ import json -import uuid from typing import Optional from fastapi import APIRouter, Depends, HTTPException, Request @@ -19,8 +18,8 @@ @router.get("", response_model=Page[schemas.Collection]) async def get_collections( request: Request, - app_id: uuid.UUID, - user_id: uuid.UUID, + app_id: str, + user_id: str, reverse: Optional[bool] = False, filter: Optional[str] = None, db=db, @@ -29,9 +28,9 @@ async def get_collections( """Get All Collections for a User Args: - app_id (uuid.UUID): The ID of the app representing the client + app_id (str): The ID of the app representing the client application using honcho - user_id (uuid.UUID): The User ID representing the user, managed by the user + user_id (str): The User ID representing the user, managed by the user Returns: list[schemas.Collection]: List of Collection objects @@ -51,8 +50,8 @@ async def get_collections( @router.get("/name/{name}", response_model=schemas.Collection) async def get_collection_by_name( request: Request, - app_id: uuid.UUID, - user_id: uuid.UUID, + app_id: str, + user_id: str, name: str, db=db, auth=Depends(auth), @@ -70,9 +69,9 @@ async def get_collection_by_name( @router.get("/{collection_id}", response_model=schemas.Collection) async def get_collection_by_id( request: Request, - app_id: uuid.UUID, - user_id: uuid.UUID, - collection_id: uuid.UUID, + app_id: str, + user_id: str, + collection_id: str, db=db, auth=Depends(auth), ) -> schemas.Collection: @@ -89,8 +88,8 @@ async def get_collection_by_id( @router.post("", response_model=schemas.Collection) async def create_collection( request: Request, - app_id: uuid.UUID, - user_id: uuid.UUID, + app_id: str, + user_id: str, collection: schemas.CollectionCreate, db=db, auth=Depends(auth), @@ -114,9 +113,9 @@ async def create_collection( @router.put("/{collection_id}", response_model=schemas.Collection) async def update_collection( request: Request, - app_id: uuid.UUID, - user_id: uuid.UUID, - collection_id: uuid.UUID, + app_id: str, + user_id: str, + collection_id: str, collection: schemas.CollectionUpdate, db=db, auth=Depends(auth), @@ -150,9 +149,9 @@ async def update_collection( @router.delete("/{collection_id}") async def delete_collection( request: Request, - app_id: uuid.UUID, - user_id: uuid.UUID, - collection_id: uuid.UUID, + app_id: str, + user_id: str, + collection_id: str, db=db, auth=Depends(auth), ): diff --git a/src/routers/documents.py b/src/routers/documents.py index 7fbd3e7..89819b6 100644 --- a/src/routers/documents.py +++ b/src/routers/documents.py @@ -1,5 +1,4 @@ import json -import uuid from typing import Optional, Sequence from fastapi import APIRouter, Depends, HTTPException, Request @@ -19,9 +18,9 @@ @router.get("", response_model=Page[schemas.Document]) async def get_documents( request: Request, - app_id: uuid.UUID, - user_id: uuid.UUID, - collection_id: uuid.UUID, + app_id: str, + user_id: str, + collection_id: str, reverse: Optional[bool] = False, filter: Optional[str] = None, db=db, @@ -56,10 +55,10 @@ async def get_documents( ) async def get_document( request: Request, - app_id: uuid.UUID, - user_id: uuid.UUID, - collection_id: uuid.UUID, - document_id: uuid.UUID, + app_id: str, + user_id: str, + collection_id: str, + document_id: str, db=db, auth=Depends(auth), ): @@ -80,9 +79,9 @@ async def get_document( @router.get("/query", response_model=Sequence[schemas.Document]) async def query_documents( request: Request, - app_id: uuid.UUID, - user_id: uuid.UUID, - collection_id: uuid.UUID, + app_id: str, + user_id: str, + collection_id: str, query: str, top_k: int = 5, filter: Optional[str] = None, @@ -108,9 +107,9 @@ async def query_documents( @router.post("", response_model=schemas.Document) async def create_document( request: Request, - app_id: uuid.UUID, - user_id: uuid.UUID, - collection_id: uuid.UUID, + app_id: str, + user_id: str, + collection_id: str, document: schemas.DocumentCreate, db=db, auth=Depends(auth), @@ -135,10 +134,10 @@ async def create_document( ) async def update_document( request: Request, - app_id: uuid.UUID, - user_id: uuid.UUID, - collection_id: uuid.UUID, - document_id: uuid.UUID, + app_id: str, + user_id: str, + collection_id: str, + document_id: str, document: schemas.DocumentUpdate, db=db, auth=Depends(auth), @@ -165,10 +164,10 @@ async def update_document( @router.delete("/{document_id}") async def delete_document( request: Request, - app_id: uuid.UUID, - user_id: uuid.UUID, - collection_id: uuid.UUID, - document_id: uuid.UUID, + app_id: str, + user_id: str, + collection_id: str, + document_id: str, db=db, auth=Depends(auth), ): diff --git a/src/routers/messages.py b/src/routers/messages.py index 95ee2b6..98625e7 100644 --- a/src/routers/messages.py +++ b/src/routers/messages.py @@ -1,5 +1,4 @@ import json -import uuid from typing import Optional from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException, Request @@ -41,7 +40,7 @@ async def enqueue(payload: dict): return try: processed_payload = { - k: str(v) if isinstance(v, uuid.UUID) else v for k, v in payload.items() + k: str(v) if isinstance(v, str) else v for k, v in payload.items() } item = QueueItem( payload=processed_payload, session_id=payload["session_id"] @@ -60,9 +59,9 @@ async def enqueue(payload: dict): @router.post("", response_model=schemas.Message) async def create_message_for_session( request: Request, - app_id: uuid.UUID, - user_id: uuid.UUID, - session_id: uuid.UUID, + app_id: str, + user_id: str, + session_id: str, message: schemas.MessageCreate, background_tasks: BackgroundTasks, db=db, @@ -71,7 +70,7 @@ async def create_message_for_session( """Adds a message to a session Args: - app_id (uuid.UUID): The ID of the app representing the client application using honcho + app_id (str): The ID of the app representing the client application using honcho user_id (str): The User ID representing the user, managed by the user session_id (int): The ID of the Session to add the message to message (schemas.MessageCreate): The Message object to add containing the message content and type @@ -109,9 +108,9 @@ async def create_message_for_session( @router.get("", response_model=Page[schemas.Message]) async def get_messages( request: Request, - app_id: uuid.UUID, - user_id: uuid.UUID, - session_id: uuid.UUID, + app_id: str, + user_id: str, + session_id: str, reverse: Optional[bool] = False, filter: Optional[str] = None, db=db, @@ -120,7 +119,7 @@ async def get_messages( """Get all messages for a session Args: - app_id (uuid.UUID): The ID of the app representing the client application using + app_id (str): The ID of the app representing the client application using honcho user_id (str): The User ID representing the user, managed by the user session_id (int): The ID of the Session to retrieve @@ -155,10 +154,10 @@ async def get_messages( @router.get("/{message_id}", response_model=schemas.Message) async def get_message( request: Request, - app_id: uuid.UUID, - user_id: uuid.UUID, - session_id: uuid.UUID, - message_id: uuid.UUID, + app_id: str, + user_id: str, + session_id: str, + message_id: str, db=db, auth=Depends(auth), ): @@ -174,10 +173,10 @@ async def get_message( @router.put("/{message_id}", response_model=schemas.Message) async def update_message( request: Request, - app_id: uuid.UUID, - user_id: uuid.UUID, - session_id: uuid.UUID, - message_id: uuid.UUID, + app_id: str, + user_id: str, + session_id: str, + message_id: str, message: schemas.MessageUpdate, db=db, auth=Depends(auth), diff --git a/src/routers/metamessages.py b/src/routers/metamessages.py index a125016..3a44634 100644 --- a/src/routers/metamessages.py +++ b/src/routers/metamessages.py @@ -1,5 +1,4 @@ import json -import uuid from typing import Optional from fastapi import APIRouter, Depends, HTTPException, Request @@ -19,9 +18,9 @@ @router.post("", response_model=schemas.Metamessage) async def create_metamessage( request: Request, - app_id: uuid.UUID, - user_id: uuid.UUID, - session_id: uuid.UUID, + app_id: str, + user_id: str, + session_id: str, metamessage: schemas.MetamessageCreate, db=db, auth=Depends(auth), @@ -29,7 +28,7 @@ async def create_metamessage( """Adds a message to a session Args: - app_id (uuid.UUID): The ID of the app representing the client application using + app_id (str): The ID of the app representing the client application using honcho user_id (str): The User ID representing the user, managed by the user session_id (int): The ID of the Session to add the message to @@ -57,10 +56,10 @@ async def create_metamessage( @router.get("", response_model=Page[schemas.Metamessage]) async def get_metamessages( request: Request, - app_id: uuid.UUID, - user_id: uuid.UUID, - session_id: uuid.UUID, - message_id: Optional[uuid.UUID] = None, + app_id: str, + user_id: str, + session_id: str, + message_id: Optional[str] = None, metamessage_type: Optional[str] = None, reverse: Optional[bool] = False, filter: Optional[str] = None, @@ -70,7 +69,7 @@ async def get_metamessages( """Get all messages for a session Args: - app_id (uuid.UUID): The ID of the app representing the client application using + app_id (str): The ID of the app representing the client application using honcho user_id (str): The User ID representing the user, managed by the user session_id (int): The ID of the Session to retrieve @@ -110,18 +109,18 @@ async def get_metamessages( ) async def get_metamessage( request: Request, - app_id: uuid.UUID, - user_id: uuid.UUID, - session_id: uuid.UUID, - message_id: uuid.UUID, - metamessage_id: uuid.UUID, + app_id: str, + user_id: str, + session_id: str, + message_id: str, + metamessage_id: str, db=db, auth=Depends(auth), ): """Get a specific Metamessage by ID Args: - app_id (uuid.UUID): The ID of the app representing the client application using + app_id (str): The ID of the app representing the client application using honcho user_id (str): The User ID representing the user, managed by the user session_id (int): The ID of the Session to retrieve @@ -151,10 +150,10 @@ async def get_metamessage( ) async def update_metamessage( request: Request, - app_id: uuid.UUID, - user_id: uuid.UUID, - session_id: uuid.UUID, - metamessage_id: uuid.UUID, + app_id: str, + user_id: str, + session_id: str, + metamessage_id: str, metamessage: schemas.MetamessageUpdate, db=db, auth=Depends(auth), diff --git a/src/routers/sessions.py b/src/routers/sessions.py index d232845..1a7c6ee 100644 --- a/src/routers/sessions.py +++ b/src/routers/sessions.py @@ -1,5 +1,4 @@ import json -import uuid from typing import Optional from fastapi import APIRouter, Depends, HTTPException @@ -19,8 +18,8 @@ @router.get("", response_model=Page[schemas.Session]) async def get_sessions( - app_id: uuid.UUID, - user_id: uuid.UUID, + app_id: str, + user_id: str, is_active: Optional[bool] = False, reverse: Optional[bool] = False, filter: Optional[str] = None, @@ -30,9 +29,9 @@ async def get_sessions( """Get All Sessions for a User Args: - app_id (uuid.UUID): The ID of the app representing the client application using + app_id (str): The ID of the app representing the client application using honcho - user_id (uuid.UUID): The User ID representing the user, managed by the user + user_id (str): The User ID representing the user, managed by the user Returns: list[schemas.Session]: List of Session objects @@ -58,8 +57,8 @@ async def get_sessions( @router.post("", response_model=schemas.Session) async def create_session( - app_id: uuid.UUID, - user_id: uuid.UUID, + app_id: str, + user_id: str, session: schemas.SessionCreate, db=db, auth=Depends(auth), @@ -67,9 +66,9 @@ async def create_session( """Create a Session for a User Args: - app_id (uuid.UUID): The ID of the app representing the client + app_id (str): The ID of the app representing the client application using honcho - user_id (uuid.UUID): The User ID representing the user, managed by the user + user_id (str): The User ID representing the user, managed by the user session (schemas.SessionCreate): The Session object containing any metadata @@ -91,9 +90,9 @@ async def create_session( @router.put("/{session_id}", response_model=schemas.Session) async def update_session( - app_id: uuid.UUID, - user_id: uuid.UUID, - session_id: uuid.UUID, + app_id: str, + user_id: str, + session_id: str, session: schemas.SessionUpdate, db=db, auth=Depends(auth), @@ -101,10 +100,10 @@ async def update_session( """Update the metadata of a Session Args: - app_id (uuid.UUID): The ID of the app representing the client application using + app_id (str): The ID of the app representing the client application using honcho - user_id (uuid.UUID): The User ID representing the user, managed by the user - session_id (uuid.UUID): The ID of the Session to update + user_id (str): The User ID representing the user, managed by the user + session_id (str): The ID of the Session to update session (schemas.SessionUpdate): The Session object containing any new metadata Returns: @@ -123,19 +122,19 @@ async def update_session( @router.delete("/{session_id}") async def delete_session( - app_id: uuid.UUID, - user_id: uuid.UUID, - session_id: uuid.UUID, + app_id: str, + user_id: str, + session_id: str, db=db, auth=Depends(auth), ): """Delete a session by marking it as inactive Args: - app_id (uuid.UUID): The ID of the app representing the client application using + app_id (str): The ID of the app representing the client application using honcho - user_id (uuid.UUID): The User ID representing the user, managed by the user - session_id (uuid.UUID): The ID of the Session to delete + user_id (str): The User ID representing the user, managed by the user + session_id (str): The ID of the Session to delete Returns: dict: A message indicating that the session was deleted @@ -155,19 +154,19 @@ async def delete_session( @router.get("/{session_id}", response_model=schemas.Session) async def get_session( - app_id: uuid.UUID, - user_id: uuid.UUID, - session_id: uuid.UUID, + app_id: str, + user_id: str, + session_id: str, db=db, auth=Depends(auth), ): """Get a specific session for a user by ID Args: - app_id (uuid.UUID): The ID of the app representing the client application using + app_id (str): The ID of the app representing the client application using honcho - user_id (uuid.UUID): The User ID representing the user, managed by the user - session_id (uuid.UUID): The ID of the Session to retrieve + user_id (str): The User ID representing the user, managed by the user + session_id (str): The ID of the Session to retrieve Returns: schemas.Session: The Session object of the requested Session @@ -185,9 +184,9 @@ async def get_session( @router.post("/{session_id}/chat", response_model=schemas.AgentChat) async def chat( - app_id: uuid.UUID, - user_id: uuid.UUID, - session_id: uuid.UUID, + app_id: str, + user_id: str, + session_id: str, query: schemas.AgentQuery, auth=Depends(auth), ): @@ -209,9 +208,9 @@ async def chat( }, ) async def get_chat_stream( - app_id: uuid.UUID, - user_id: uuid.UUID, - session_id: uuid.UUID, + app_id: str, + user_id: str, + session_id: str, query: schemas.AgentQuery, auth=Depends(auth), ): diff --git a/src/routers/users.py b/src/routers/users.py index c52d8b9..6809b00 100644 --- a/src/routers/users.py +++ b/src/routers/users.py @@ -1,5 +1,4 @@ import json -import uuid from typing import Optional from fastapi import APIRouter, Depends, HTTPException, Request @@ -20,7 +19,7 @@ @router.post("", response_model=schemas.User) async def create_user( request: Request, - app_id: uuid.UUID, + app_id: str, user: schemas.UserCreate, db=db, auth=Depends(auth), @@ -28,7 +27,7 @@ async def create_user( """Create a User Args: - app_id (uuid.UUID): The ID of the app representing the client application using + app_id (str): The ID of the app representing the client application using honcho user (schemas.UserCreate): The User object containing any metadata @@ -48,7 +47,7 @@ async def create_user( @router.get("", response_model=Page[schemas.User]) async def get_users( request: Request, - app_id: uuid.UUID, + app_id: str, reverse: bool = False, filter: Optional[str] = None, db=db, @@ -57,7 +56,7 @@ async def get_users( """Get All Users for an App Args: - app_id (uuid.UUID): The ID of the app representing the client + app_id (str): The ID of the app representing the client application using honcho Returns: @@ -76,7 +75,7 @@ async def get_users( @router.get("/name/{name}", response_model=schemas.User) async def get_user_by_name( request: Request, - app_id: uuid.UUID, + app_id: str, name: str, db=db, auth=Depends(auth), @@ -84,7 +83,7 @@ async def get_user_by_name( """Get a User Args: - app_id (uuid.UUID): The ID of the app representing the client application using + app_id (str): The ID of the app representing the client application using honcho user_id (str): The User ID representing the user, managed by the user @@ -101,15 +100,15 @@ async def get_user_by_name( @router.get("/{user_id}", response_model=schemas.User) async def get_user( request: Request, - app_id: uuid.UUID, - user_id: uuid.UUID, + app_id: str, + user_id: str, db=db, auth=Depends(auth), ): """Get a User Args: - app_id (uuid.UUID): The ID of the app representing the client application using + app_id (str): The ID of the app representing the client application using honcho user_id (str): The User ID representing the user, managed by the user @@ -125,12 +124,12 @@ async def get_user( @router.get("/get_or_create/{name}", response_model=schemas.User) async def get_or_create_user( - request: Request, app_id: uuid.UUID, name: str, db=db, auth=Depends(auth) + request: Request, app_id: str, name: str, db=db, auth=Depends(auth) ): """Get or Create a User Args: - app_id (uuid.UUID): The ID of the app representing the client application using + app_id (str): The ID of the app representing the client application using honcho user_id (str): The User ID representing the user, managed by the user @@ -149,8 +148,8 @@ async def get_or_create_user( @router.put("/{user_id}", response_model=schemas.User) async def update_user( request: Request, - app_id: uuid.UUID, - user_id: uuid.UUID, + app_id: str, + user_id: str, user: schemas.UserUpdate, db=db, auth=Depends(auth), @@ -158,7 +157,7 @@ async def update_user( """Update a User Args: - app_id (uuid.UUID): The ID of the app representing the client application using + app_id (str): The ID of the app representing the client application using honcho user_id (str): The User ID representing the user, managed by the user user (schemas.UserCreate): The User object containing any metadata diff --git a/src/schemas.py b/src/schemas.py index c6ce033..fcab8f1 100644 --- a/src/schemas.py +++ b/src/schemas.py @@ -1,7 +1,6 @@ import datetime -import uuid -from pydantic import BaseModel, ConfigDict, Field, field_validator +from pydantic import BaseModel, ConfigDict, Field class AppBase(BaseModel): @@ -19,19 +18,14 @@ class AppUpdate(AppBase): class App(AppBase): - id: uuid.UUID + id: str = Field(alias="public_id") name: str - h_metadata: dict = Field(exclude=True) - metadata: dict + metadata: dict = Field(alias="h_metadata") created_at: datetime.datetime - @field_validator("metadata", mode="before") - def fetch_h_metadata(cls, value, info): - return info.data.get("h_metadata", {}) - model_config = ConfigDict( from_attributes=True, - json_schema_extra={"exclude": ["h_metadata"]}, + populate_by_name=True, ) @@ -50,20 +44,15 @@ class UserUpdate(UserBase): class User(UserBase): - id: uuid.UUID + id: str = Field(alias="public_id") name: str - app_id: uuid.UUID + app_id: str created_at: datetime.datetime - h_metadata: dict = Field(exclude=True) - metadata: dict - - @field_validator("metadata", mode="before") - def fetch_h_metadata(cls, value, info): - return info.data.get("h_metadata", {}) + metadata: dict = Field(alias="h_metadata") model_config = ConfigDict( from_attributes=True, - json_schema_extra={"exclude": ["h_metadata"]}, + populate_by_name=True, ) @@ -82,21 +71,16 @@ class MessageUpdate(MessageBase): class Message(MessageBase): + id: str = Field(alias="public_id") content: str is_user: bool - session_id: uuid.UUID - id: uuid.UUID - h_metadata: dict = Field(exclude=True) - metadata: dict + session_id: str + metadata: dict = Field(alias="h_metadata") created_at: datetime.datetime - @field_validator("metadata", mode="before") - def fetch_h_metadata(cls, value, info): - return info.data.get("h_metadata", {}) - model_config = ConfigDict( from_attributes=True, - json_schema_extra={"exclude": ["h_metadata"]}, + populate_by_name=True, ) @@ -113,21 +97,16 @@ class SessionUpdate(SessionBase): class Session(SessionBase): - id: uuid.UUID + id: str = Field(alias="public_id") # messages: list[Message] is_active: bool - user_id: uuid.UUID - h_metadata: dict = Field(exclude=True) - metadata: dict + user_id: str + metadata: dict = Field(alias="h_metadata") created_at: datetime.datetime - @field_validator("metadata", mode="before") - def fetch_h_metadata(cls, value, info): - return info.data.get("h_metadata", {}) - model_config = ConfigDict( from_attributes=True, - json_schema_extra={"exclude": ["h_metadata"]}, + populate_by_name=True, ) @@ -138,31 +117,27 @@ class MetamessageBase(BaseModel): class MetamessageCreate(MetamessageBase): metamessage_type: str content: str - message_id: uuid.UUID + message_id: str metadata: dict | None = {} class MetamessageUpdate(MetamessageBase): - message_id: uuid.UUID + message_id: str metamessage_type: str | None = None metadata: dict | None = None class Metamessage(MetamessageBase): + id: str = Field(alias="public_id") metamessage_type: str content: str - id: uuid.UUID - message_id: uuid.UUID - h_metadata: dict = Field(exclude=True) - metadata: dict + message_id: str + metadata: dict = Field(alias="h_metadata") created_at: datetime.datetime - @field_validator("metadata", mode="before") - def fetch_h_metadata(cls, value, info): - return info.data.get("h_metadata", {}) - model_config = ConfigDict( from_attributes=True, + populate_by_name=True, json_schema_extra={"exclude": ["h_metadata"]}, ) @@ -182,20 +157,15 @@ class CollectionUpdate(CollectionBase): class Collection(CollectionBase): - id: uuid.UUID + id: str = Field(alias="public_id") name: str - user_id: uuid.UUID - h_metadata: dict = Field(exclude=True) - metadata: dict + user_id: str + metadata: dict = Field(alias="h_metadata") created_at: datetime.datetime - @field_validator("metadata", mode="before") - def fetch_h_metadata(cls, value, info): - return info.data.get("h_metadata", {}) - model_config = ConfigDict( from_attributes=True, - json_schema_extra={"exclude": ["h_metadata"]}, + populate_by_name=True, ) @@ -214,20 +184,15 @@ class DocumentUpdate(DocumentBase): class Document(DocumentBase): - id: uuid.UUID + id: str = Field(alias="public_id") content: str - h_metadata: dict = Field(exclude=True) - metadata: dict + metadata: dict = Field(alias="h_metadata") created_at: datetime.datetime - collection_id: uuid.UUID - - @field_validator("metadata", mode="before") - def fetch_h_metadata(cls, value, info): - return info.data.get("h_metadata", {}) + collection_id: str model_config = ConfigDict( from_attributes=True, - json_schema_extra={"exclude": ["h_metadata"]}, + populate_by_name=True, ) diff --git a/tests/conftest.py b/tests/conftest.py index 279a670..269beed 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,7 +1,7 @@ import logging # noqa: I001 import os import sys -import uuid +from nanoid import generate as generate_nanoid import pytest import pytest_asyncio @@ -125,12 +125,12 @@ async def override_get_db(): async def sample_data(db_session): """Helper function to create test data""" # Create test app - test_app = models.App(name=str(uuid.uuid4())) + test_app = models.App(name=str(generate_nanoid())) db_session.add(test_app) await db_session.flush() # Create test user - test_user = models.User(name=str(uuid.uuid4()), app_id=test_app.id) + test_user = models.User(name=str(generate_nanoid()), app_id=test_app.id) db_session.add(test_user) await db_session.flush() diff --git a/tests/routes/test_apps.py b/tests/routes/test_apps.py index 3588e97..6b34bc1 100644 --- a/tests/routes/test_apps.py +++ b/tests/routes/test_apps.py @@ -1,4 +1,4 @@ -import uuid +from nanoid import generate as generate_nanoid import pytest from sqlalchemy.ext.asyncio import AsyncSession @@ -7,7 +7,7 @@ def test_create_app(client): - name = str(uuid.uuid4()) + name = str(generate_nanoid()) response = client.post("/apps", json={"name": name, "metadata": {"key": "value"}}) print(response) assert response.status_code == 200 @@ -18,7 +18,7 @@ def test_create_app(client): def test_get_or_create_app(client): - name = str(uuid.uuid4()) + name = str(generate_nanoid()) response = client.get(f"/apps/name/{name}") assert response.status_code == 404 response = client.get(f"/apps/get_or_create/{name}") @@ -29,7 +29,7 @@ def test_get_or_create_app(client): def test_get_or_create_existing_app(client): - name = str(uuid.uuid4()) + name = str(generate_nanoid()) response = client.get(f"/apps/name/{name}") assert response.status_code == 404 response = client.post("/apps", json={"name": name, "metadata": {"key": "value"}}) @@ -63,7 +63,7 @@ def test_get_app_by_name(client, sample_data): def test_update_app(client, sample_data): test_app, _ = sample_data - new_name = str(uuid.uuid4()) + new_name = str(generate_nanoid()) response = client.put( f"/apps/{test_app.id}", json={"name": new_name, "metadata": {"new_key": "new_value"}}, diff --git a/tests/routes/test_users.py b/tests/routes/test_users.py index 46bdc19..ad08d58 100644 --- a/tests/routes/test_users.py +++ b/tests/routes/test_users.py @@ -1,9 +1,9 @@ -import uuid +from nanoid import generate as generate_nanoid def test_create_user(client, sample_data): test_app, _ = sample_data - name = str(uuid.uuid4()) + name = str(generate_nanoid()) response = client.post( f"/apps/{test_app.id}/users", json={"name": name, "metadata": {"user_key": "user_value"}}, @@ -35,7 +35,7 @@ def test_get_user_by_name(client, sample_data): def test_get_or_create_user(client, sample_data): test_app, _ = sample_data - name = str(uuid.uuid4()) + name = str(generate_nanoid()) response = client.get(f"/apps/{test_app.id}/users/name/{name}") assert response.status_code == 404 response = client.get(f"/apps/{test_app.id}/users/get_or_create/{name}") @@ -56,7 +56,7 @@ def test_get_or_create_user(client, sample_data): def test_update_user(client, sample_data): test_app, test_user = sample_data - new_name = str(uuid.uuid4()) + new_name = str(generate_nanoid()) response = client.put( f"/apps/{test_app.id}/users/{test_user.id}", json={"name": new_name, "metadata": {"new_key": "new_value"}}, diff --git a/uv.lock b/uv.lock index 6c3ecb0..8d12779 100644 --- a/uv.lock +++ b/uv.lock @@ -463,6 +463,7 @@ dependencies = [ { name = "greenlet" }, { name = "httpx" }, { name = "mirascope" }, + { name = "nanoid" }, { name = "openai" }, { name = "opentelemetry-exporter-otlp" }, { name = "opentelemetry-instrumentation-fastapi" }, @@ -493,6 +494,7 @@ requires-dist = [ { name = "greenlet", specifier = ">=3.0.3" }, { name = "httpx", specifier = ">=0.27.0" }, { name = "mirascope", specifier = ">=0.18.0" }, + { name = "nanoid", specifier = ">=2.0.0" }, { name = "openai", specifier = ">=1.43.0" }, { name = "opentelemetry-exporter-otlp", specifier = ">=1.24.0" }, { name = "opentelemetry-instrumentation-fastapi", specifier = ">=0.45b0" }, @@ -780,6 +782,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/35/ba/024ae6eeb9593d3fd7b6c47852800c3cd60fceb496c846dc1b7e20d68bd3/mirascope-0.18.3-py3-none-any.whl", hash = "sha256:2abc00f55feec29295ffd09a54a6403768a1c78085612e8caee0379cba26ea19", size = 113671 }, ] +[[package]] +name = "nanoid" +version = "2.0.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/b7/9d/0250bf5935d88e214df469d35eccc0f6ff7e9db046fc8a9aeb4b2a192775/nanoid-2.0.0.tar.gz", hash = "sha256:5a80cad5e9c6e9ae3a41fa2fb34ae189f7cb420b2a5d8f82bd9d23466e4efa68", size = 3290 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/2e/0d/8630f13998638dc01e187fadd2e5c6d42d127d08aeb4943d231664d6e539/nanoid-2.0.0-py3-none-any.whl", hash = "sha256:90aefa650e328cffb0893bbd4c236cfd44c48bc1f2d0b525ecc53c3187b653bb", size = 5844 }, +] + [[package]] name = "numpy" version = "2.0.2" From a61c7a1ef9720b02e281dfd015456c9bca366b82 Mon Sep 17 00:00:00 2001 From: Vineeth Voruganti <13438633+VVoruganti@users.noreply.github.com> Date: Wed, 16 Oct 2024 17:19:54 -0400 Subject: [PATCH 3/5] feat(pydantic) Updates models and routes to get tests passing --- pyproject.toml | 3 +- src/models.py | 16 ++--- src/schemas.py | 116 ++++++++++++++++++++++++------ tests/conftest.py | 2 +- tests/routes/test_apps.py | 11 +-- tests/routes/test_collections.py | 20 +++--- tests/routes/test_documents.py | 24 +++---- tests/routes/test_messages.py | 16 ++--- tests/routes/test_metamessages.py | 33 +++++---- tests/routes/test_sessions.py | 18 ++--- tests/routes/test_users.py | 18 ++--- 11 files changed, 178 insertions(+), 99 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index bda1cc6..705f219 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -57,5 +57,6 @@ ignore = ["E501"] [tool.ruff.flake8-bugbear] extend-immutable-calls = ["fastapi.Depends"] -[tool.lpytest.ini_options] +[tool.pytest.ini_options] asyncio_mode = "auto" +asyncio_default_fixture_loop_scope = "session" diff --git a/src/models.py b/src/models.py index 4fcd305..2c4acb5 100644 --- a/src/models.py +++ b/src/models.py @@ -34,7 +34,7 @@ class App(Base): created_at: Mapped[datetime.datetime] = mapped_column( DateTime(timezone=True), index=True, default=func.now() ) - h_metadata: Mapped[dict] = mapped_column("h_metadata", JSONB, default={}) + h_metadata: Mapped[dict] = mapped_column("metadata", JSONB, default={}) class User(Base): @@ -46,7 +46,7 @@ class User(Base): String(21), index=True, unique=True, default=generate_nanoid ) name: Mapped[str] = mapped_column(String(512), index=True) - h_metadata: Mapped[dict] = mapped_column("h_metadata", JSONB, default={}) + h_metadata: Mapped[dict] = mapped_column("metadata", JSONB, default={}) created_at: Mapped[datetime.datetime] = mapped_column( DateTime(timezone=True), index=True, default=func.now() ) @@ -58,7 +58,7 @@ class User(Base): __table_args__ = (UniqueConstraint("name", "app_id", name="unique_name_app_user"),) def __repr__(self) -> str: - return f"User(id={self.id}, app_id={self.app_id}, user_id={self.user_id}, created_at={self.created_at}, h_metadata={self.h_metadata})" + return f"User(id={self.id}, app_id={self.app_id}, user_id={self.id}, created_at={self.created_at}, h_metadata={self.h_metadata})" class Session(Base): @@ -70,7 +70,7 @@ class Session(Base): String(21), index=True, unique=True, default=generate_nanoid ) is_active: Mapped[bool] = mapped_column(default=True) - h_metadata: Mapped[dict] = mapped_column("h_metadata", JSONB, default={}) + h_metadata: Mapped[dict] = mapped_column("metadata", JSONB, default={}) created_at: Mapped[datetime.datetime] = mapped_column( DateTime(timezone=True), index=True, default=func.now() ) @@ -95,7 +95,7 @@ class Message(Base): ) is_user: Mapped[bool] content: Mapped[str] = mapped_column(String(65535)) - h_metadata: Mapped[dict] = mapped_column("h_metadata", JSONB, default={}) + h_metadata: Mapped[dict] = mapped_column("metadata", JSONB, default={}) created_at: Mapped[datetime.datetime] = mapped_column( DateTime(timezone=True), index=True, default=func.now() @@ -125,7 +125,7 @@ class Metamessage(Base): created_at: Mapped[datetime.datetime] = mapped_column( DateTime(timezone=True), index=True, default=func.now() ) - h_metadata: Mapped[dict] = mapped_column("h_metadata", JSONB, default={}) + h_metadata: Mapped[dict] = mapped_column("metadata", JSONB, default={}) def __repr__(self) -> str: return f"Metamessages(id={self.id}, message_id={self.message_id}, metamessage_type={self.metamessage_type}, content={self.content[10:]})" @@ -144,7 +144,7 @@ class Collection(Base): created_at: Mapped[datetime.datetime] = mapped_column( DateTime(timezone=True), index=True, default=func.now() ) - h_metadata: Mapped[dict] = mapped_column("h_metadata", JSONB, default={}) + h_metadata: Mapped[dict] = mapped_column("metadata", JSONB, default={}) documents = relationship( "Document", back_populates="collection", cascade="all, delete, delete-orphan" ) @@ -166,7 +166,7 @@ class Document(Base): public_id: Mapped[str] = mapped_column( String(21), index=True, unique=True, default=generate_nanoid ) - h_metadata: Mapped[dict] = mapped_column("h_metadata", JSONB, default={}) + h_metadata: Mapped[dict] = mapped_column("metadata", JSONB, default={}) content: Mapped[str] = mapped_column(String(65535)) embedding = mapped_column(Vector(1536)) created_at: Mapped[datetime.datetime] = mapped_column( diff --git a/src/schemas.py b/src/schemas.py index fcab8f1..0f7f51e 100644 --- a/src/schemas.py +++ b/src/schemas.py @@ -1,6 +1,6 @@ import datetime -from pydantic import BaseModel, ConfigDict, Field +from pydantic import BaseModel, ConfigDict, Field, field_validator class AppBase(BaseModel): @@ -18,14 +18,24 @@ class AppUpdate(AppBase): class App(AppBase): - id: str = Field(alias="public_id") + public_id: str = Field(exclude=True) + id: str name: str - metadata: dict = Field(alias="h_metadata") + h_metadata: dict = Field(exclude=True) + metadata: dict created_at: datetime.datetime + @field_validator("metadata", mode="before") + def fetch_h_metadata(cls, value, info): + return info.data.get("h_metadata", {}) + + @field_validator("id", mode="before") + def internal_to_public(cls, value, info): + return info.data.get("public_id", {}) + model_config = ConfigDict( from_attributes=True, - populate_by_name=True, + json_schema_extra={"exclude": ["h_metadata", "public_id"]}, ) @@ -44,15 +54,25 @@ class UserUpdate(UserBase): class User(UserBase): - id: str = Field(alias="public_id") + public_id: str = Field(exclude=True) + id: str name: str app_id: str created_at: datetime.datetime - metadata: dict = Field(alias="h_metadata") + h_metadata: dict = Field(exclude=True) + metadata: dict + + @field_validator("metadata", mode="before") + def fetch_h_metadata(cls, value, info): + return info.data.get("h_metadata", {}) + + @field_validator("id", mode="before") + def internal_to_public(cls, value, info): + return info.data.get("public_id", {}) model_config = ConfigDict( from_attributes=True, - populate_by_name=True, + json_schema_extra={"exclude": ["h_metadata", "public_id"]}, ) @@ -71,16 +91,26 @@ class MessageUpdate(MessageBase): class Message(MessageBase): - id: str = Field(alias="public_id") + public_id: str = Field(exclude=True) + id: str content: str is_user: bool session_id: str - metadata: dict = Field(alias="h_metadata") + h_metadata: dict = Field(exclude=True) + metadata: dict created_at: datetime.datetime + @field_validator("metadata", mode="before") + def fetch_h_metadata(cls, value, info): + return info.data.get("h_metadata", {}) + + @field_validator("id", mode="before") + def internal_to_public(cls, value, info): + return info.data.get("public_id", {}) + model_config = ConfigDict( from_attributes=True, - populate_by_name=True, + json_schema_extra={"exclude": ["h_metadata", "public_id"]}, ) @@ -97,16 +127,27 @@ class SessionUpdate(SessionBase): class Session(SessionBase): - id: str = Field(alias="public_id") + public_id: str = Field(exclude=True) + id: str # messages: list[Message] is_active: bool user_id: str - metadata: dict = Field(alias="h_metadata") + h_metadata: dict = Field(exclude=True) + metadata: dict + created_at: datetime.datetime + @field_validator("metadata", mode="before") + def fetch_h_metadata(cls, value, info): + return info.data.get("h_metadata", {}) + + @field_validator("id", mode="before") + def internal_to_public(cls, value, info): + return info.data.get("public_id", {}) + model_config = ConfigDict( from_attributes=True, - populate_by_name=True, + json_schema_extra={"exclude": ["h_metadata", "public_id"]}, ) @@ -128,17 +169,26 @@ class MetamessageUpdate(MetamessageBase): class Metamessage(MetamessageBase): - id: str = Field(alias="public_id") + public_id: str = Field(exclude=True) + id: str metamessage_type: str content: str message_id: str - metadata: dict = Field(alias="h_metadata") + h_metadata: dict = Field(exclude=True) + metadata: dict created_at: datetime.datetime + @field_validator("metadata", mode="before") + def fetch_h_metadata(cls, value, info): + return info.data.get("h_metadata", {}) + + @field_validator("id", mode="before") + def internal_to_public(cls, value, info): + return info.data.get("public_id", {}) + model_config = ConfigDict( from_attributes=True, - populate_by_name=True, - json_schema_extra={"exclude": ["h_metadata"]}, + json_schema_extra={"exclude": ["h_metadata", "public_id"]}, ) @@ -157,15 +207,25 @@ class CollectionUpdate(CollectionBase): class Collection(CollectionBase): - id: str = Field(alias="public_id") + public_id: str = Field(exclude=True) + id: str name: str user_id: str - metadata: dict = Field(alias="h_metadata") + h_metadata: dict = Field(exclude=True) + metadata: dict created_at: datetime.datetime + @field_validator("metadata", mode="before") + def fetch_h_metadata(cls, value, info): + return info.data.get("h_metadata", {}) + + @field_validator("id", mode="before") + def internal_to_public(cls, value, info): + return info.data.get("public_id", {}) + model_config = ConfigDict( from_attributes=True, - populate_by_name=True, + json_schema_extra={"exclude": ["h_metadata", "public_id"]}, ) @@ -184,15 +244,25 @@ class DocumentUpdate(DocumentBase): class Document(DocumentBase): - id: str = Field(alias="public_id") + public_id: str = Field(exclude=True) + id: str content: str - metadata: dict = Field(alias="h_metadata") + h_metadata: dict = Field(exclude=True) + metadata: dict created_at: datetime.datetime collection_id: str + @field_validator("metadata", mode="before") + def fetch_h_metadata(cls, value, info): + return info.data.get("h_metadata", {}) + + @field_validator("id", mode="before") + def internal_to_public(cls, value, info): + return info.data.get("public_id", {}) + model_config = ConfigDict( from_attributes=True, - populate_by_name=True, + json_schema_extra={"exclude": ["h_metadata", "public_id"]}, ) diff --git a/tests/conftest.py b/tests/conftest.py index 269beed..34882f6 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -130,7 +130,7 @@ async def sample_data(db_session): await db_session.flush() # Create test user - test_user = models.User(name=str(generate_nanoid()), app_id=test_app.id) + test_user = models.User(name=str(generate_nanoid()), app_id=test_app.public_id) db_session.add(test_user) await db_session.flush() diff --git a/tests/routes/test_apps.py b/tests/routes/test_apps.py index 6b34bc1..1c8b165 100644 --- a/tests/routes/test_apps.py +++ b/tests/routes/test_apps.py @@ -12,6 +12,9 @@ def test_create_app(client): print(response) assert response.status_code == 200 data = response.json() + print("===================") + print(data) + print("===================") assert data["name"] == name assert data["metadata"] == {"key": "value"} assert "id" in data @@ -45,11 +48,11 @@ def test_get_or_create_existing_app(client): def test_get_app_by_id(client, sample_data): test_app, _ = sample_data - response = client.get(f"/apps/{test_app.id}") + response = client.get(f"/apps/{test_app.public_id}") assert response.status_code == 200 data = response.json() assert data["name"] == test_app.name - assert data["id"] == str(test_app.id) + assert data["id"] == str(test_app.public_id) def test_get_app_by_name(client, sample_data): @@ -58,14 +61,14 @@ def test_get_app_by_name(client, sample_data): assert response.status_code == 200 data = response.json() assert data["name"] == test_app.name - assert data["id"] == str(test_app.id) + assert data["id"] == str(test_app.public_id) def test_update_app(client, sample_data): test_app, _ = sample_data new_name = str(generate_nanoid()) response = client.put( - f"/apps/{test_app.id}", + f"/apps/{test_app.public_id}", json={"name": new_name, "metadata": {"new_key": "new_value"}}, ) assert response.status_code == 200 diff --git a/tests/routes/test_collections.py b/tests/routes/test_collections.py index bb59d6b..af1b570 100644 --- a/tests/routes/test_collections.py +++ b/tests/routes/test_collections.py @@ -4,7 +4,7 @@ def test_create_collection(client, sample_data) -> None: test_app, test_user = sample_data response = client.post( - f"/apps/{test_app.id}/users/{test_user.id}/collections", + f"/apps/{test_app.public_id}/users/{test_user.public_id}/collections", json={"name": "test_collection", "metadata": {}}, ) assert response.status_code == 200 @@ -18,14 +18,14 @@ def test_get_collection_by_id(client, sample_data) -> None: test_app, test_user = sample_data # Make the collection response = client.post( - f"/apps/{test_app.id}/users/{test_user.id}/collections", + f"/apps/{test_app.public_id}/users/{test_user.public_id}/collections", json={"name": "test_collection", "metadata": {}}, ) assert response.status_code == 200 data = response.json() # Get the collection response = client.get( - f"/apps/{test_app.id}/users/{test_user.id}/collections/{data['id']}" + f"/apps/{test_app.public_id}/users/{test_user.public_id}/collections/{data['id']}" ) assert response.status_code == 200 data = response.json() @@ -38,14 +38,14 @@ def test_get_collection_by_name(client, sample_data) -> None: test_app, test_user = sample_data # Make the collection response = client.post( - f"/apps/{test_app.id}/users/{test_user.id}/collections", + f"/apps/{test_app.public_id}/users/{test_user.public_id}/collections", json={"name": "test_collection", "metadata": {}}, ) assert response.status_code == 200 data = response.json() # Get the collection response = client.get( - f"/apps/{test_app.id}/users/{test_user.id}/collections/name/test_collection" + f"/apps/{test_app.public_id}/users/{test_user.public_id}/collections/name/test_collection" ) assert response.status_code == 200 data = response.json() @@ -58,14 +58,14 @@ def test_update_collection(client, sample_data) -> None: test_app, test_user = sample_data # Make the collection response = client.post( - f"/apps/{test_app.id}/users/{test_user.id}/collections", + f"/apps/{test_app.public_id}/users/{test_user.public_id}/collections", json={"name": "test_collection", "metadata": {}}, ) assert response.status_code == 200 data = response.json() # Update the collection response = client.put( - f"/apps/{test_app.id}/users/{test_user.id}/collections/{data['id']}", + f"/apps/{test_app.public_id}/users/{test_user.public_id}/collections/{data['id']}", json={"name": "test_collection_updated", "metadata": {"new_key": "new_value"}}, ) assert response.status_code == 200 @@ -79,17 +79,17 @@ def test_delete_collection(client, sample_data) -> None: test_app, test_user = sample_data # Make the collection response = client.post( - f"/apps/{test_app.id}/users/{test_user.id}/collections", + f"/apps/{test_app.public_id}/users/{test_user.public_id}/collections", json={"name": "test_collection", "metadata": {}}, ) assert response.status_code == 200 data = response.json() # Delete the collection response = client.delete( - f"/apps/{test_app.id}/users/{test_user.id}/collections/{data['id']}" + f"/apps/{test_app.public_id}/users/{test_user.public_id}/collections/{data['id']}" ) assert response.status_code == 200 response = client.get( - f"/apps/{test_app.id}/users/{test_user.id}/collections/{data['id']}" + f"/apps/{test_app.public_id}/users/{test_user.public_id}/collections/{data['id']}" ) assert response.status_code == 404 diff --git a/tests/routes/test_documents.py b/tests/routes/test_documents.py index c9ac0e5..5b701a6 100644 --- a/tests/routes/test_documents.py +++ b/tests/routes/test_documents.py @@ -2,14 +2,14 @@ def test_create_document(client, sample_data): test_app, test_user = sample_data # Create a collection response = client.post( - f"/apps/{test_app.id}/users/{test_user.id}/collections", + f"/apps/{test_app.public_id}/users/{test_user.public_id}/collections", json={"name": "test_collection", "metadata": {}}, ) assert response.status_code == 200 data = response.json() # Create a document response = client.post( - f"/apps/{test_app.id}/users/{test_user.id}/collections/{data['id']}/documents", + f"/apps/{test_app.public_id}/users/{test_user.public_id}/collections/{data['id']}/documents", json={"content": "test_text", "metadata": {}}, ) assert response.status_code == 200 @@ -23,21 +23,21 @@ def test_get_document(client, sample_data): test_app, test_user = sample_data # Create a collection response = client.post( - f"/apps/{test_app.id}/users/{test_user.id}/collections", + f"/apps/{test_app.public_id}/users/{test_user.public_id}/collections", json={"name": "test_collection", "metadata": {}}, ) assert response.status_code == 200 collection = response.json() # Create a document response = client.post( - f"/apps/{test_app.id}/users/{test_user.id}/collections/{collection['id']}/documents", + f"/apps/{test_app.public_id}/users/{test_user.public_id}/collections/{collection['id']}/documents", json={"content": "test_text", "metadata": {}}, ) assert response.status_code == 200 document = response.json() # Get the document response = client.get( - f"/apps/{test_app.id}/users/{test_user.id}/collections/{collection['id']}/documents/{document['id']}" + f"/apps/{test_app.public_id}/users/{test_user.public_id}/collections/{collection['id']}/documents/{document['id']}" ) assert response.status_code == 200 data = response.json() @@ -50,21 +50,21 @@ def test_update_document(client, sample_data): test_app, test_user = sample_data # Create a collection response = client.post( - f"/apps/{test_app.id}/users/{test_user.id}/collections", + f"/apps/{test_app.public_id}/users/{test_user.public_id}/collections", json={"name": "test_collection", "metadata": {}}, ) assert response.status_code == 200 data = response.json() # Create a document response = client.post( - f"/apps/{test_app.id}/users/{test_user.id}/collections/{data['id']}/documents", + f"/apps/{test_app.public_id}/users/{test_user.public_id}/collections/{data['id']}/documents", json={"content": "test_text", "metadata": {}}, ) assert response.status_code == 200 data = response.json() # Update the document response = client.put( - f"/apps/{test_app.id}/users/{test_user.id}/collections/{data['id']}/documents/{data['id']}", + f"/apps/{test_app.public_id}/users/{test_user.public_id}/collections/{data['id']}/documents/{data['id']}", json={"content": "test_text_updated", "metadata": {"new_key": "new_value"}}, ) @@ -73,24 +73,24 @@ def test_delete_document(client, sample_data): test_app, test_user = sample_data # Create a collection response = client.post( - f"/apps/{test_app.id}/users/{test_user.id}/collections", + f"/apps/{test_app.public_id}/users/{test_user.public_id}/collections", json={"name": "test_collection", "metadata": {}}, ) assert response.status_code == 200 collection = response.json() # Create a document response = client.post( - f"/apps/{test_app.id}/users/{test_user.id}/collections/{collection['id']}/documents", + f"/apps/{test_app.public_id}/users/{test_user.public_id}/collections/{collection['id']}/documents", json={"content": "test_text", "metadata": {}}, ) assert response.status_code == 200 document = response.json() # Delete the document response = client.delete( - f"/apps/{test_app.id}/users/{test_user.id}/collections/{collection['id']}/documents/{document['id']}" + f"/apps/{test_app.public_id}/users/{test_user.public_id}/collections/{collection['id']}/documents/{document['id']}" ) assert response.status_code == 200 response = client.get( - f"/apps/{test_app.id}/users/{test_user.id}/collections/{collection['id']}/documents/{document['id']}" + f"/apps/{test_app.public_id}/users/{test_user.public_id}/collections/{collection['id']}/documents/{document['id']}" ) assert response.status_code == 404 diff --git a/tests/routes/test_messages.py b/tests/routes/test_messages.py index 000207d..d66237c 100644 --- a/tests/routes/test_messages.py +++ b/tests/routes/test_messages.py @@ -7,12 +7,12 @@ async def test_create_message(client, db_session, sample_data): test_app, test_user = sample_data # Create a test session - test_session = models.Session(user_id=test_user.id) + test_session = models.Session(user_id=test_user.public_id) db_session.add(test_session) await db_session.commit() response = client.post( - f"/apps/{test_app.id}/users/{test_user.id}/sessions/{test_session.id}/messages", + f"/apps/{test_app.public_id}/users/{test_user.public_id}/sessions/{test_session.public_id}/messages", json={ "content": "Test message", "is_user": True, @@ -31,17 +31,17 @@ async def test_create_message(client, db_session, sample_data): async def test_get_messages(client, db_session, sample_data): test_app, test_user = sample_data # Create a test session and message - test_session = models.Session(user_id=test_user.id) + test_session = models.Session(user_id=test_user.public_id) db_session.add(test_session) await db_session.commit() test_message = models.Message( - session_id=test_session.id, content="Test message", is_user=True + session_id=test_session.public_id, content="Test message", is_user=True ) db_session.add(test_message) await db_session.commit() response = client.get( - f"/apps/{test_app.id}/users/{test_user.id}/sessions/{test_session.id}/messages" + f"/apps/{test_app.public_id}/users/{test_user.public_id}/sessions/{test_session.public_id}/messages" ) assert response.status_code == 200 data = response.json() @@ -56,17 +56,17 @@ async def test_get_messages(client, db_session, sample_data): async def test_update_message(client, db_session, sample_data): test_app, test_user = sample_data # Create a test session and message - test_session = models.Session(user_id=test_user.id) + test_session = models.Session(user_id=test_user.public_id) db_session.add(test_session) await db_session.commit() test_message = models.Message( - session_id=test_session.id, content="Test message", is_user=True + session_id=test_session.public_id, content="Test message", is_user=True ) db_session.add(test_message) await db_session.commit() response = client.put( - f"/apps/{test_app.id}/users/{test_user.id}/sessions/{test_session.id}/messages/{test_message.id}", + f"/apps/{test_app.public_id}/users/{test_user.public_id}/sessions/{test_session.public_id}/messages/{test_message.public_id}", json={"metadata": {"new_key": "new_value"}}, ) assert response.status_code == 200 diff --git a/tests/routes/test_metamessages.py b/tests/routes/test_metamessages.py index 2d7691b..b0b6618 100644 --- a/tests/routes/test_metamessages.py +++ b/tests/routes/test_metamessages.py @@ -7,19 +7,19 @@ async def test_create_metamessage(client, db_session, sample_data): test_app, test_user = sample_data # Create a test session - test_session = models.Session(user_id=test_user.id) + test_session = models.Session(user_id=test_user.public_id) db_session.add(test_session) await db_session.commit() test_message = models.Message( - session_id=test_session.id, content="Test message", is_user=True + session_id=test_session.public_id, content="Test message", is_user=True ) db_session.add(test_message) await db_session.commit() response = client.post( - f"/apps/{test_app.id}/users/{test_user.id}/sessions/{test_session.id}/metamessages", + f"/apps/{test_app.public_id}/users/{test_user.public_id}/sessions/{test_session.public_id}/metamessages", json={ - "message_id": str(test_message.id), + "message_id": str(test_message.public_id), "content": "Test Metamessage", "metadata": {}, "metamessage_type": "test_type", @@ -27,7 +27,7 @@ async def test_create_metamessage(client, db_session, sample_data): ) assert response.status_code == 200 data = response.json() - assert data["message_id"] == str(test_message.id) + assert data["message_id"] == str(test_message.public_id) assert data["content"] == "Test Metamessage" assert data["metadata"] == {} assert data["metamessage_type"] == "test_type" @@ -37,16 +37,16 @@ async def test_create_metamessage(client, db_session, sample_data): async def test_get_metamessage(client, db_session, sample_data): test_app, test_user = sample_data # Create a test session - test_session = models.Session(user_id=test_user.id) + test_session = models.Session(user_id=test_user.public_id) db_session.add(test_session) await db_session.commit() test_message = models.Message( - session_id=test_session.id, content="Test message", is_user=True + session_id=test_session.public_id, content="Test message", is_user=True ) db_session.add(test_message) await db_session.commit() test_metamessage = models.Metamessage( - message_id=test_message.id, + message_id=test_message.public_id, content="Test Metamessage", metadata={}, metamessage_type="test_type", @@ -55,11 +55,11 @@ async def test_get_metamessage(client, db_session, sample_data): await db_session.commit() response = client.get( - f"/apps/{test_app.id}/users/{test_user.id}/sessions/{test_session.id}/metamessages/{test_metamessage.id}?message_id={test_message.id}" + f"/apps/{test_app.public_id}/users/{test_user.public_id}/sessions/{test_session.public_id}/metamessages/{test_metamessage.public_id}?message_id={test_message.public_id}" ) assert response.status_code == 200 data = response.json() - assert data["message_id"] == str(test_message.id) + assert data["message_id"] == str(test_message.public_id) assert data["content"] == "Test Metamessage" assert data["metadata"] == {} assert data["metamessage_type"] == "test_type" @@ -69,16 +69,16 @@ async def test_get_metamessage(client, db_session, sample_data): async def test_update_metamessage(client, db_session, sample_data): test_app, test_user = sample_data # Create a test session - test_session = models.Session(user_id=test_user.id) + test_session = models.Session(user_id=test_user.public_id) db_session.add(test_session) await db_session.commit() test_message = models.Message( - session_id=test_session.id, content="Test message", is_user=True + session_id=test_session.public_id, content="Test message", is_user=True ) db_session.add(test_message) await db_session.commit() test_metamessage = models.Metamessage( - message_id=test_message.id, + message_id=test_message.public_id, content="Test Metamessage", metadata={}, metamessage_type="test_type", @@ -87,8 +87,11 @@ async def test_update_metamessage(client, db_session, sample_data): await db_session.commit() response = client.put( - f"/apps/{test_app.id}/users/{test_user.id}/sessions/{test_session.id}/metamessages/{test_metamessage.id}", - json={"message_id": str(test_message.id), "metadata": {"new_key": "new_value"}}, + f"/apps/{test_app.public_id}/users/{test_user.public_id}/sessions/{test_session.public_id}/metamessages/{test_metamessage.public_id}", + json={ + "message_id": str(test_message.public_id), + "metadata": {"new_key": "new_value"}, + }, ) assert response.status_code == 200 data = response.json() diff --git a/tests/routes/test_sessions.py b/tests/routes/test_sessions.py index 33da903..538c2ef 100644 --- a/tests/routes/test_sessions.py +++ b/tests/routes/test_sessions.py @@ -6,7 +6,7 @@ def test_create_session(client, sample_data): test_app, test_user = sample_data response = client.post( - f"/apps/{test_app.id}/users/{test_user.id}/sessions", + f"/apps/{test_app.public_id}/users/{test_user.public_id}/sessions", json={ "metadata": {"session_key": "session_value"}, }, @@ -22,7 +22,7 @@ async def test_get_sessions(client, db_session, sample_data): test_app, test_user = sample_data # Create a test session response = client.post( - f"/apps/{test_app.id}/users/{test_user.id}/sessions", + f"/apps/{test_app.public_id}/users/{test_user.public_id}/sessions", json={ "metadata": {"test_key": "test_value"}, }, @@ -32,7 +32,9 @@ async def test_get_sessions(client, db_session, sample_data): assert data["metadata"] == {"test_key": "test_value"} assert "id" in data - response = client.get(f"/apps/{test_app.id}/users/{test_user.id}/sessions") + response = client.get( + f"/apps/{test_app.public_id}/users/{test_user.public_id}/sessions" + ) assert response.status_code == 200 data = response.json() assert "items" in data @@ -44,12 +46,12 @@ async def test_get_sessions(client, db_session, sample_data): async def test_update_session(client, db_session, sample_data): test_app, test_user = sample_data # Create a test session - test_session = models.Session(user_id=test_user.id, metadata={}) + test_session = models.Session(user_id=test_user.public_id, metadata={}) db_session.add(test_session) await db_session.commit() response = client.put( - f"/apps/{test_app.id}/users/{test_user.id}/sessions/{test_session.id}", + f"/apps/{test_app.public_id}/users/{test_user.public_id}/sessions/{test_session.public_id}", json={"metadata": {"new_key": "new_value"}}, ) assert response.status_code == 200 @@ -61,15 +63,15 @@ async def test_update_session(client, db_session, sample_data): async def test_delete_session(client, db_session, sample_data): test_app, test_user = sample_data # Create a test session - test_session = models.Session(user_id=test_user.id, metadata={}) + test_session = models.Session(user_id=test_user.public_id, metadata={}) db_session.add(test_session) await db_session.commit() response = client.delete( - f"/apps/{test_app.id}/users/{test_user.id}/sessions/{test_session.id}" + f"/apps/{test_app.public_id}/users/{test_user.public_id}/sessions/{test_session.public_id}" ) assert response.status_code == 200 response = client.get( - f"/apps/{test_app.id}/users/{test_user.id}/sessions/{test_session.id}" + f"/apps/{test_app.public_id}/users/{test_user.public_id}/sessions/{test_session.public_id}" ) data = response.json() assert data["is_active"] is False diff --git a/tests/routes/test_users.py b/tests/routes/test_users.py index ad08d58..5bd17bb 100644 --- a/tests/routes/test_users.py +++ b/tests/routes/test_users.py @@ -5,7 +5,7 @@ def test_create_user(client, sample_data): test_app, _ = sample_data name = str(generate_nanoid()) response = client.post( - f"/apps/{test_app.id}/users", + f"/apps/{test_app.public_id}/users", json={"name": name, "metadata": {"user_key": "user_value"}}, ) assert response.status_code == 200 @@ -17,28 +17,28 @@ def test_create_user(client, sample_data): def test_get_user_by_id(client, sample_data): test_app, test_user = sample_data - response = client.get(f"/apps/{test_app.id}/users/{test_user.id}") + response = client.get(f"/apps/{test_app.public_id}/users/{test_user.public_id}") assert response.status_code == 200 data = response.json() assert data["name"] == test_user.name - assert data["id"] == str(test_user.id) + assert data["id"] == str(test_user.public_id) def test_get_user_by_name(client, sample_data): test_app, test_user = sample_data - response = client.get(f"/apps/{test_app.id}/users/name/{test_user.name}") + response = client.get(f"/apps/{test_app.public_id}/users/name/{test_user.name}") assert response.status_code == 200 data = response.json() assert data["name"] == test_user.name - assert data["id"] == str(test_user.id) + assert data["id"] == str(test_user.public_id) def test_get_or_create_user(client, sample_data): test_app, _ = sample_data name = str(generate_nanoid()) - response = client.get(f"/apps/{test_app.id}/users/name/{name}") + response = client.get(f"/apps/{test_app.public_id}/users/name/{name}") assert response.status_code == 404 - response = client.get(f"/apps/{test_app.id}/users/get_or_create/{name}") + response = client.get(f"/apps/{test_app.public_id}/users/get_or_create/{name}") assert response.status_code == 200 data = response.json() assert data["name"] == name @@ -47,7 +47,7 @@ def test_get_or_create_user(client, sample_data): # def test_get_users(client, sample_data): # test_app, _ = sample_data -# response = client.get(f"/apps/{test_app.id}/users") +# response = client.get(f"/apps/{test_app.public_id}/users") # assert response.status_code == 200 # data = response.json() # assert "items" in data @@ -58,7 +58,7 @@ def test_update_user(client, sample_data): test_app, test_user = sample_data new_name = str(generate_nanoid()) response = client.put( - f"/apps/{test_app.id}/users/{test_user.id}", + f"/apps/{test_app.public_id}/users/{test_user.public_id}", json={"name": new_name, "metadata": {"new_key": "new_value"}}, ) assert response.status_code == 200 From 2bc5c743d6869c8654a9bfd53c2dc9374a2af46f Mon Sep 17 00:00:00 2001 From: Vineeth Voruganti <13438633+VVoruganti@users.noreply.github.com> Date: Wed, 16 Oct 2024 17:52:46 -0400 Subject: [PATCH 4/5] fix(deriver) fix enqueue method --- src/routers/messages.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/routers/messages.py b/src/routers/messages.py index 98625e7..7e458da 100644 --- a/src/routers/messages.py +++ b/src/routers/messages.py @@ -42,9 +42,7 @@ async def enqueue(payload: dict): processed_payload = { k: str(v) if isinstance(v, str) else v for k, v in payload.items() } - item = QueueItem( - payload=processed_payload, session_id=payload["session_id"] - ) + item = QueueItem(payload=processed_payload, session_id=session.id) db.add(item) await db.commit() return From c65a0f48fcd1504eb35acb7b1a3770d1dcacc62e Mon Sep 17 00:00:00 2001 From: Vineeth Voruganti <13438633+VVoruganti@users.noreply.github.com> Date: Thu, 17 Oct 2024 12:40:16 -0400 Subject: [PATCH 5/5] Perf Testing Utilities --- .python-version | 1 + pyproject.toml | 1 + uv.lock | 16 ++++++++++++++++ 3 files changed, 18 insertions(+) create mode 100644 .python-version diff --git a/.python-version b/.python-version new file mode 100644 index 0000000..2c07333 --- /dev/null +++ b/.python-version @@ -0,0 +1 @@ +3.11 diff --git a/pyproject.toml b/pyproject.toml index 705f219..7286e10 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,6 +34,7 @@ dev-dependencies = [ "pytest-asyncio>=0.23.7", "coverage>=7.6.0", "interrogate>=1.7.0", + "py-spy>=0.3.14", ] [tool.ruff.lint] diff --git a/uv.lock b/uv.lock index 8d12779..bbd4596 100644 --- a/uv.lock +++ b/uv.lock @@ -482,6 +482,7 @@ dependencies = [ dev = [ { name = "coverage" }, { name = "interrogate" }, + { name = "py-spy" }, { name = "pytest" }, { name = "pytest-asyncio" }, { name = "sqlalchemy-utils" }, @@ -513,6 +514,7 @@ requires-dist = [ dev = [ { name = "coverage", specifier = ">=7.6.0" }, { name = "interrogate", specifier = ">=1.7.0" }, + { name = "py-spy", specifier = ">=0.3.14" }, { name = "pytest", specifier = ">=8.2.2" }, { name = "pytest-asyncio", specifier = ">=0.23.7" }, { name = "sqlalchemy-utils", specifier = ">=0.41.2" }, @@ -1181,6 +1183,20 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/f6/f0/10642828a8dfb741e5f3fbaac830550a518a775c7fff6f04a007259b0548/py-1.11.0-py2.py3-none-any.whl", hash = "sha256:607c53218732647dff4acdfcd50cb62615cedf612e72d1724fb1a0cc6405b378", size = 98708 }, ] +[[package]] +name = "py-spy" +version = "0.3.14" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/b4/ff/9f10044630e55eb6fca1c32a00710901be9410b71031711fae6690acc643/py_spy-0.3.14-py2.py3-none-macosx_10_7_x86_64.whl", hash = "sha256:5b342cc5feb8d160d57a7ff308de153f6be68dcf506ad02b4d67065f2bae7f45", size = 1576751 }, + { url = "https://files.pythonhosted.org/packages/4c/f3/ace9005f101cb7d41bd69081ea4d095950a31bb6df8a26cf142928e7658f/py_spy-0.3.14-py2.py3-none-macosx_10_9_x86_64.macosx_11_0_arm64.macosx_10_9_universal2.whl", hash = "sha256:fe7efe6c91f723442259d428bf1f9ddb9c1679828866b353d539345ca40d9dd2", size = 3049462 }, + { url = "https://files.pythonhosted.org/packages/5e/53/404550ee909148afbffd9c93723001f895656a6be1ecb57598c12f08e429/py_spy-0.3.14-py2.py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:590905447241d789d9de36cff9f52067b6f18d8b5e9fb399242041568d414461", size = 2578909 }, + { url = "https://files.pythonhosted.org/packages/3d/14/47fa8b0cb7e9c95117682dfcfaf6245256553bdff7ca83ba4b89db226858/py_spy-0.3.14-py2.py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:fd6211fe7f587b3532ba9d300784326d9a6f2b890af7bf6fff21a029ebbc812b", size = 2703439 }, + { url = "https://files.pythonhosted.org/packages/e6/6b/d49bc425ab0369f8f2661af047202fb45f7aa804b971db461775b035c8db/py_spy-0.3.14-py2.py3-none-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:3e8e48032e71c94c3dd51694c39e762e4bbfec250df5bf514adcdd64e79371e0", size = 2709489 }, + { url = "https://files.pythonhosted.org/packages/c2/d2/082de8db2285a652a00a39f2bcffaaf0b0c9c378f4830bb5983d2600b2dd/py_spy-0.3.14-py2.py3-none-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:f59b0b52e56ba9566305236375e6fc68888261d0d36b5addbe3cf85affbefc0e", size = 3012916 }, + { url = "https://files.pythonhosted.org/packages/9d/eb/1749a892ada87c65320503d8be2aa31ba1e034cc818aead54ec45d347719/py_spy-0.3.14-py2.py3-none-win_amd64.whl", hash = "sha256:8f5b311d09f3a8e33dbd0d44fc6e37b715e8e0c7efefafcda8bfd63b31ab5a31", size = 1446535 }, +] + [[package]] name = "pydantic" version = "2.9.1"