Skip to content

Commit

Permalink
Add tests for prisma backend classes
Browse files Browse the repository at this point in the history
  • Loading branch information
kumaranvpl committed Jul 30, 2024
1 parent c4206ef commit 4d3725b
Show file tree
Hide file tree
Showing 2 changed files with 148 additions and 0 deletions.
Empty file added tests/db/__init__.py
Empty file.
148 changes: 148 additions & 0 deletions tests/db/test_prisma.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
import random
import uuid
from datetime import datetime, timedelta
from typing import Any, Dict

import pytest

import fastagency.db
import fastagency.db.prisma
from fastagency.db.base import BackendDBProtocol, FrontendDBProtocol
from fastagency.db.prisma import PrismaBackendDB, PrismaFrontendDB
from fastagency.models.llms.azure import AzureOAIAPIKey


@pytest.mark.asyncio()
class TestPrismaFrontendDB:
async def test_set_default(self) -> None:
prisma_frontend_db = PrismaFrontendDB()
with FrontendDBProtocol.set_default(prisma_frontend_db):
assert FrontendDBProtocol._default_db == prisma_frontend_db

async def test_db(self) -> None:
prisma_frontend_db = PrismaFrontendDB()
with FrontendDBProtocol.set_default(prisma_frontend_db):
assert FrontendDBProtocol.db() == prisma_frontend_db

async def test_create_user_get_user(self) -> None:
prisma_frontend_db = PrismaFrontendDB()

random_id = random.randint(1, 1_000_000)
generated_uuid = str(uuid.uuid4())
email = f"user{random_id}@airt.ai"
username = f"user{random_id}"

user_uuid = await prisma_frontend_db._create_user(
generated_uuid, email, username
)
assert user_uuid == generated_uuid

user = await prisma_frontend_db.get_user(user_uuid)
assert user["uuid"] == user_uuid
assert user["email"] == email
assert user["username"] == username


@pytest.mark.asyncio()
class TestPrismaBackendDB:
async def test_set_default(self) -> None:
prisma_backend_db = PrismaBackendDB()
with BackendDBProtocol.set_default(prisma_backend_db):
assert BackendDBProtocol._default_db == prisma_backend_db

async def test_db(self) -> None:
prisma_backend_db = PrismaBackendDB()
with BackendDBProtocol.set_default(prisma_backend_db):
assert BackendDBProtocol.db() == prisma_backend_db

async def test_model_CRUD(self) -> None: # noqa: N802
# Setup
prisma_frontend_db = PrismaFrontendDB()
prisma_backend_db = PrismaBackendDB()
random_id = random.randint(1, 1_000_000)
user_uuid = await prisma_frontend_db._create_user(
str(uuid.uuid4()), f"user{random_id}@airt.ai", f"user{random_id}"
)
model_uuid = str(uuid.uuid4())
azure_oai_api_key = AzureOAIAPIKey(api_key="whatever", name="who cares?")

# Tests
model = await prisma_backend_db.create_model(
user_uuid=user_uuid,
model_uuid=model_uuid,
type_name="secret",
model_name="AzureOAIAPIKey",
json_str=azure_oai_api_key.model_dump_json(),
)
assert model["uuid"] == model_uuid
assert model["user_uuid"] == user_uuid
assert model["type_name"] == "secret"
assert model["model_name"] == "AzureOAIAPIKey"
assert model["json_str"] == azure_oai_api_key.model_dump()

found_model = await prisma_backend_db.find_model(model_uuid)
assert found_model["uuid"] == model_uuid

many_model = await prisma_backend_db.find_many_model(user_uuid)
assert len(many_model) == 1
assert many_model[0]["uuid"] == model_uuid

updated_model = await prisma_backend_db.update_model(
model_uuid=model_uuid,
user_uuid=user_uuid,
type_name="secret",
model_name="AzureOAIAPIKey2",
json_str=azure_oai_api_key.model_dump_json(),
)
assert updated_model["uuid"] == model_uuid
assert updated_model["model_name"] == "AzureOAIAPIKey2"

deleted_model = await prisma_backend_db.delete_model(model_uuid)
assert deleted_model["uuid"] == model_uuid

async def test_auth_token_CRUD(self, monkeypatch: pytest.MonkeyPatch) -> None: # noqa: N802
# Setup
prisma_frontend_db = PrismaFrontendDB()
prisma_backend_db = PrismaBackendDB()
random_id = random.randint(1, 1_000_000)
user_uuid = await prisma_frontend_db._create_user(
str(uuid.uuid4()), f"user{random_id}@airt.ai", f"user{random_id}"
)
deployment_uuid = str(uuid.uuid4())
auth_token_uuid = str(uuid.uuid4())

async def mock_find_model(*args: Any, **kwargs: Any) -> Dict[str, str]:
return {
"user_uuid": user_uuid,
"uuid": deployment_uuid,
}

monkeypatch.setattr(
fastagency.db.prisma.PrismaBackendDB,
"find_model",
mock_find_model,
)

# Tests
auth_token = await prisma_backend_db.create_auth_token(
auth_token_uuid=auth_token_uuid,
name="Test token",
user_uuid=user_uuid,
deployment_uuid=deployment_uuid,
hashed_auth_token="whatever",
expiry="99d",
expires_at=datetime.utcnow() + timedelta(days=99),
)
assert auth_token["uuid"] == auth_token_uuid
assert auth_token["name"] == "Test token"

many_auth_token = await prisma_backend_db.find_many_auth_token(
user_uuid, deployment_uuid
)
assert len(many_auth_token) == 1
assert many_auth_token[0]["uuid"] == auth_token_uuid

deleted_auth_token = await prisma_backend_db.delete_auth_token(
auth_token_uuid, deployment_uuid, user_uuid
)
assert deleted_auth_token["uuid"] == auth_token_uuid

0 comments on commit 4d3725b

Please sign in to comment.