Skip to content

Commit

Permalink
feat(pydantic) Updates models and routes to get tests passing
Browse files Browse the repository at this point in the history
  • Loading branch information
VVoruganti committed Oct 16, 2024
1 parent fe49dfc commit a61c7a1
Show file tree
Hide file tree
Showing 11 changed files with 178 additions and 99 deletions.
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -57,5 +57,6 @@ ignore = ["E501"]
[tool.ruff.flake8-bugbear]
extend-immutable-calls = ["fastapi.Depends"]

[tool.lpytest.ini_options]
[tool.pytest.ini_options]
asyncio_mode = "auto"
asyncio_default_fixture_loop_scope = "session"
16 changes: 8 additions & 8 deletions src/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ class App(Base):
created_at: Mapped[datetime.datetime] = mapped_column(
DateTime(timezone=True), index=True, default=func.now()
)
h_metadata: Mapped[dict] = mapped_column("h_metadata", JSONB, default={})
h_metadata: Mapped[dict] = mapped_column("metadata", JSONB, default={})


class User(Base):
Expand All @@ -46,7 +46,7 @@ class User(Base):
String(21), index=True, unique=True, default=generate_nanoid
)
name: Mapped[str] = mapped_column(String(512), index=True)
h_metadata: Mapped[dict] = mapped_column("h_metadata", JSONB, default={})
h_metadata: Mapped[dict] = mapped_column("metadata", JSONB, default={})
created_at: Mapped[datetime.datetime] = mapped_column(
DateTime(timezone=True), index=True, default=func.now()
)
Expand All @@ -58,7 +58,7 @@ class User(Base):
__table_args__ = (UniqueConstraint("name", "app_id", name="unique_name_app_user"),)

def __repr__(self) -> str:
return f"User(id={self.id}, app_id={self.app_id}, user_id={self.user_id}, created_at={self.created_at}, h_metadata={self.h_metadata})"
return f"User(id={self.id}, app_id={self.app_id}, user_id={self.id}, created_at={self.created_at}, h_metadata={self.h_metadata})"


class Session(Base):
Expand All @@ -70,7 +70,7 @@ class Session(Base):
String(21), index=True, unique=True, default=generate_nanoid
)
is_active: Mapped[bool] = mapped_column(default=True)
h_metadata: Mapped[dict] = mapped_column("h_metadata", JSONB, default={})
h_metadata: Mapped[dict] = mapped_column("metadata", JSONB, default={})
created_at: Mapped[datetime.datetime] = mapped_column(
DateTime(timezone=True), index=True, default=func.now()
)
Expand All @@ -95,7 +95,7 @@ class Message(Base):
)
is_user: Mapped[bool]
content: Mapped[str] = mapped_column(String(65535))
h_metadata: Mapped[dict] = mapped_column("h_metadata", JSONB, default={})
h_metadata: Mapped[dict] = mapped_column("metadata", JSONB, default={})

created_at: Mapped[datetime.datetime] = mapped_column(
DateTime(timezone=True), index=True, default=func.now()
Expand Down Expand Up @@ -125,7 +125,7 @@ class Metamessage(Base):
created_at: Mapped[datetime.datetime] = mapped_column(
DateTime(timezone=True), index=True, default=func.now()
)
h_metadata: Mapped[dict] = mapped_column("h_metadata", JSONB, default={})
h_metadata: Mapped[dict] = mapped_column("metadata", JSONB, default={})

def __repr__(self) -> str:
return f"Metamessages(id={self.id}, message_id={self.message_id}, metamessage_type={self.metamessage_type}, content={self.content[10:]})"
Expand All @@ -144,7 +144,7 @@ class Collection(Base):
created_at: Mapped[datetime.datetime] = mapped_column(
DateTime(timezone=True), index=True, default=func.now()
)
h_metadata: Mapped[dict] = mapped_column("h_metadata", JSONB, default={})
h_metadata: Mapped[dict] = mapped_column("metadata", JSONB, default={})
documents = relationship(
"Document", back_populates="collection", cascade="all, delete, delete-orphan"
)
Expand All @@ -166,7 +166,7 @@ class Document(Base):
public_id: Mapped[str] = mapped_column(
String(21), index=True, unique=True, default=generate_nanoid
)
h_metadata: Mapped[dict] = mapped_column("h_metadata", JSONB, default={})
h_metadata: Mapped[dict] = mapped_column("metadata", JSONB, default={})
content: Mapped[str] = mapped_column(String(65535))
embedding = mapped_column(Vector(1536))
created_at: Mapped[datetime.datetime] = mapped_column(
Expand Down
116 changes: 93 additions & 23 deletions src/schemas.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import datetime

from pydantic import BaseModel, ConfigDict, Field
from pydantic import BaseModel, ConfigDict, Field, field_validator


class AppBase(BaseModel):
Expand All @@ -18,14 +18,24 @@ class AppUpdate(AppBase):


class App(AppBase):
id: str = Field(alias="public_id")
public_id: str = Field(exclude=True)
id: str
name: str
metadata: dict = Field(alias="h_metadata")
h_metadata: dict = Field(exclude=True)
metadata: dict
created_at: datetime.datetime

@field_validator("metadata", mode="before")
def fetch_h_metadata(cls, value, info):
return info.data.get("h_metadata", {})

@field_validator("id", mode="before")
def internal_to_public(cls, value, info):
return info.data.get("public_id", {})

model_config = ConfigDict(
from_attributes=True,
populate_by_name=True,
json_schema_extra={"exclude": ["h_metadata", "public_id"]},
)


Expand All @@ -44,15 +54,25 @@ class UserUpdate(UserBase):


class User(UserBase):
id: str = Field(alias="public_id")
public_id: str = Field(exclude=True)
id: str
name: str
app_id: str
created_at: datetime.datetime
metadata: dict = Field(alias="h_metadata")
h_metadata: dict = Field(exclude=True)
metadata: dict

@field_validator("metadata", mode="before")
def fetch_h_metadata(cls, value, info):
return info.data.get("h_metadata", {})

@field_validator("id", mode="before")
def internal_to_public(cls, value, info):
return info.data.get("public_id", {})

model_config = ConfigDict(
from_attributes=True,
populate_by_name=True,
json_schema_extra={"exclude": ["h_metadata", "public_id"]},
)


Expand All @@ -71,16 +91,26 @@ class MessageUpdate(MessageBase):


class Message(MessageBase):
id: str = Field(alias="public_id")
public_id: str = Field(exclude=True)
id: str
content: str
is_user: bool
session_id: str
metadata: dict = Field(alias="h_metadata")
h_metadata: dict = Field(exclude=True)
metadata: dict
created_at: datetime.datetime

@field_validator("metadata", mode="before")
def fetch_h_metadata(cls, value, info):
return info.data.get("h_metadata", {})

@field_validator("id", mode="before")
def internal_to_public(cls, value, info):
return info.data.get("public_id", {})

model_config = ConfigDict(
from_attributes=True,
populate_by_name=True,
json_schema_extra={"exclude": ["h_metadata", "public_id"]},
)


Expand All @@ -97,16 +127,27 @@ class SessionUpdate(SessionBase):


class Session(SessionBase):
id: str = Field(alias="public_id")
public_id: str = Field(exclude=True)
id: str
# messages: list[Message]
is_active: bool
user_id: str
metadata: dict = Field(alias="h_metadata")
h_metadata: dict = Field(exclude=True)
metadata: dict

created_at: datetime.datetime

@field_validator("metadata", mode="before")
def fetch_h_metadata(cls, value, info):
return info.data.get("h_metadata", {})

@field_validator("id", mode="before")
def internal_to_public(cls, value, info):
return info.data.get("public_id", {})

model_config = ConfigDict(
from_attributes=True,
populate_by_name=True,
json_schema_extra={"exclude": ["h_metadata", "public_id"]},
)


Expand All @@ -128,17 +169,26 @@ class MetamessageUpdate(MetamessageBase):


class Metamessage(MetamessageBase):
id: str = Field(alias="public_id")
public_id: str = Field(exclude=True)
id: str
metamessage_type: str
content: str
message_id: str
metadata: dict = Field(alias="h_metadata")
h_metadata: dict = Field(exclude=True)
metadata: dict
created_at: datetime.datetime

@field_validator("metadata", mode="before")
def fetch_h_metadata(cls, value, info):
return info.data.get("h_metadata", {})

@field_validator("id", mode="before")
def internal_to_public(cls, value, info):
return info.data.get("public_id", {})

model_config = ConfigDict(
from_attributes=True,
populate_by_name=True,
json_schema_extra={"exclude": ["h_metadata"]},
json_schema_extra={"exclude": ["h_metadata", "public_id"]},
)


Expand All @@ -157,15 +207,25 @@ class CollectionUpdate(CollectionBase):


class Collection(CollectionBase):
id: str = Field(alias="public_id")
public_id: str = Field(exclude=True)
id: str
name: str
user_id: str
metadata: dict = Field(alias="h_metadata")
h_metadata: dict = Field(exclude=True)
metadata: dict
created_at: datetime.datetime

@field_validator("metadata", mode="before")
def fetch_h_metadata(cls, value, info):
return info.data.get("h_metadata", {})

@field_validator("id", mode="before")
def internal_to_public(cls, value, info):
return info.data.get("public_id", {})

model_config = ConfigDict(
from_attributes=True,
populate_by_name=True,
json_schema_extra={"exclude": ["h_metadata", "public_id"]},
)


Expand All @@ -184,15 +244,25 @@ class DocumentUpdate(DocumentBase):


class Document(DocumentBase):
id: str = Field(alias="public_id")
public_id: str = Field(exclude=True)
id: str
content: str
metadata: dict = Field(alias="h_metadata")
h_metadata: dict = Field(exclude=True)
metadata: dict
created_at: datetime.datetime
collection_id: str

@field_validator("metadata", mode="before")
def fetch_h_metadata(cls, value, info):
return info.data.get("h_metadata", {})

@field_validator("id", mode="before")
def internal_to_public(cls, value, info):
return info.data.get("public_id", {})

model_config = ConfigDict(
from_attributes=True,
populate_by_name=True,
json_schema_extra={"exclude": ["h_metadata", "public_id"]},
)


Expand Down
2 changes: 1 addition & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ async def sample_data(db_session):
await db_session.flush()

# Create test user
test_user = models.User(name=str(generate_nanoid()), app_id=test_app.id)
test_user = models.User(name=str(generate_nanoid()), app_id=test_app.public_id)
db_session.add(test_user)
await db_session.flush()

Expand Down
11 changes: 7 additions & 4 deletions tests/routes/test_apps.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@ def test_create_app(client):
print(response)
assert response.status_code == 200
data = response.json()
print("===================")
print(data)
print("===================")
assert data["name"] == name
assert data["metadata"] == {"key": "value"}
assert "id" in data
Expand Down Expand Up @@ -45,11 +48,11 @@ def test_get_or_create_existing_app(client):

def test_get_app_by_id(client, sample_data):
test_app, _ = sample_data
response = client.get(f"/apps/{test_app.id}")
response = client.get(f"/apps/{test_app.public_id}")
assert response.status_code == 200
data = response.json()
assert data["name"] == test_app.name
assert data["id"] == str(test_app.id)
assert data["id"] == str(test_app.public_id)


def test_get_app_by_name(client, sample_data):
Expand All @@ -58,14 +61,14 @@ def test_get_app_by_name(client, sample_data):
assert response.status_code == 200
data = response.json()
assert data["name"] == test_app.name
assert data["id"] == str(test_app.id)
assert data["id"] == str(test_app.public_id)


def test_update_app(client, sample_data):
test_app, _ = sample_data
new_name = str(generate_nanoid())
response = client.put(
f"/apps/{test_app.id}",
f"/apps/{test_app.public_id}",
json={"name": new_name, "metadata": {"new_key": "new_value"}},
)
assert response.status_code == 200
Expand Down
Loading

0 comments on commit a61c7a1

Please sign in to comment.