-
Notifications
You must be signed in to change notification settings - Fork 30
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add tests for prisma backend classes
- Loading branch information
1 parent
c4206ef
commit 4d3725b
Showing
2 changed files
with
148 additions
and
0 deletions.
There are no files selected for viewing
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,148 @@ | ||
import random | ||
import uuid | ||
from datetime import datetime, timedelta | ||
from typing import Any, Dict | ||
|
||
import pytest | ||
|
||
import fastagency.db | ||
import fastagency.db.prisma | ||
from fastagency.db.base import BackendDBProtocol, FrontendDBProtocol | ||
from fastagency.db.prisma import PrismaBackendDB, PrismaFrontendDB | ||
from fastagency.models.llms.azure import AzureOAIAPIKey | ||
|
||
|
||
@pytest.mark.asyncio() | ||
class TestPrismaFrontendDB: | ||
async def test_set_default(self) -> None: | ||
prisma_frontend_db = PrismaFrontendDB() | ||
with FrontendDBProtocol.set_default(prisma_frontend_db): | ||
assert FrontendDBProtocol._default_db == prisma_frontend_db | ||
|
||
async def test_db(self) -> None: | ||
prisma_frontend_db = PrismaFrontendDB() | ||
with FrontendDBProtocol.set_default(prisma_frontend_db): | ||
assert FrontendDBProtocol.db() == prisma_frontend_db | ||
|
||
async def test_create_user_get_user(self) -> None: | ||
prisma_frontend_db = PrismaFrontendDB() | ||
|
||
random_id = random.randint(1, 1_000_000) | ||
generated_uuid = str(uuid.uuid4()) | ||
email = f"user{random_id}@airt.ai" | ||
username = f"user{random_id}" | ||
|
||
user_uuid = await prisma_frontend_db._create_user( | ||
generated_uuid, email, username | ||
) | ||
assert user_uuid == generated_uuid | ||
|
||
user = await prisma_frontend_db.get_user(user_uuid) | ||
assert user["uuid"] == user_uuid | ||
assert user["email"] == email | ||
assert user["username"] == username | ||
|
||
|
||
@pytest.mark.asyncio() | ||
class TestPrismaBackendDB: | ||
async def test_set_default(self) -> None: | ||
prisma_backend_db = PrismaBackendDB() | ||
with BackendDBProtocol.set_default(prisma_backend_db): | ||
assert BackendDBProtocol._default_db == prisma_backend_db | ||
|
||
async def test_db(self) -> None: | ||
prisma_backend_db = PrismaBackendDB() | ||
with BackendDBProtocol.set_default(prisma_backend_db): | ||
assert BackendDBProtocol.db() == prisma_backend_db | ||
|
||
async def test_model_CRUD(self) -> None: # noqa: N802 | ||
# Setup | ||
prisma_frontend_db = PrismaFrontendDB() | ||
prisma_backend_db = PrismaBackendDB() | ||
random_id = random.randint(1, 1_000_000) | ||
user_uuid = await prisma_frontend_db._create_user( | ||
str(uuid.uuid4()), f"user{random_id}@airt.ai", f"user{random_id}" | ||
) | ||
model_uuid = str(uuid.uuid4()) | ||
azure_oai_api_key = AzureOAIAPIKey(api_key="whatever", name="who cares?") | ||
|
||
# Tests | ||
model = await prisma_backend_db.create_model( | ||
user_uuid=user_uuid, | ||
model_uuid=model_uuid, | ||
type_name="secret", | ||
model_name="AzureOAIAPIKey", | ||
json_str=azure_oai_api_key.model_dump_json(), | ||
) | ||
assert model["uuid"] == model_uuid | ||
assert model["user_uuid"] == user_uuid | ||
assert model["type_name"] == "secret" | ||
assert model["model_name"] == "AzureOAIAPIKey" | ||
assert model["json_str"] == azure_oai_api_key.model_dump() | ||
|
||
found_model = await prisma_backend_db.find_model(model_uuid) | ||
assert found_model["uuid"] == model_uuid | ||
|
||
many_model = await prisma_backend_db.find_many_model(user_uuid) | ||
assert len(many_model) == 1 | ||
assert many_model[0]["uuid"] == model_uuid | ||
|
||
updated_model = await prisma_backend_db.update_model( | ||
model_uuid=model_uuid, | ||
user_uuid=user_uuid, | ||
type_name="secret", | ||
model_name="AzureOAIAPIKey2", | ||
json_str=azure_oai_api_key.model_dump_json(), | ||
) | ||
assert updated_model["uuid"] == model_uuid | ||
assert updated_model["model_name"] == "AzureOAIAPIKey2" | ||
|
||
deleted_model = await prisma_backend_db.delete_model(model_uuid) | ||
assert deleted_model["uuid"] == model_uuid | ||
|
||
async def test_auth_token_CRUD(self, monkeypatch: pytest.MonkeyPatch) -> None: # noqa: N802 | ||
# Setup | ||
prisma_frontend_db = PrismaFrontendDB() | ||
prisma_backend_db = PrismaBackendDB() | ||
random_id = random.randint(1, 1_000_000) | ||
user_uuid = await prisma_frontend_db._create_user( | ||
str(uuid.uuid4()), f"user{random_id}@airt.ai", f"user{random_id}" | ||
) | ||
deployment_uuid = str(uuid.uuid4()) | ||
auth_token_uuid = str(uuid.uuid4()) | ||
|
||
async def mock_find_model(*args: Any, **kwargs: Any) -> Dict[str, str]: | ||
return { | ||
"user_uuid": user_uuid, | ||
"uuid": deployment_uuid, | ||
} | ||
|
||
monkeypatch.setattr( | ||
fastagency.db.prisma.PrismaBackendDB, | ||
"find_model", | ||
mock_find_model, | ||
) | ||
|
||
# Tests | ||
auth_token = await prisma_backend_db.create_auth_token( | ||
auth_token_uuid=auth_token_uuid, | ||
name="Test token", | ||
user_uuid=user_uuid, | ||
deployment_uuid=deployment_uuid, | ||
hashed_auth_token="whatever", | ||
expiry="99d", | ||
expires_at=datetime.utcnow() + timedelta(days=99), | ||
) | ||
assert auth_token["uuid"] == auth_token_uuid | ||
assert auth_token["name"] == "Test token" | ||
|
||
many_auth_token = await prisma_backend_db.find_many_auth_token( | ||
user_uuid, deployment_uuid | ||
) | ||
assert len(many_auth_token) == 1 | ||
assert many_auth_token[0]["uuid"] == auth_token_uuid | ||
|
||
deleted_auth_token = await prisma_backend_db.delete_auth_token( | ||
auth_token_uuid, deployment_uuid, user_uuid | ||
) | ||
assert deleted_auth_token["uuid"] == auth_token_uuid |