Skip to content

Commit

Permalink
Use protocol
Browse files Browse the repository at this point in the history
  • Loading branch information
kumaranvpl committed Jul 26, 2024
1 parent a677fd3 commit ac3e520
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 21 deletions.
36 changes: 20 additions & 16 deletions fastagency/db/base.py
Original file line number Diff line number Diff line change
@@ -1,36 +1,40 @@
from contextlib import asynccontextmanager
from typing import Any, AsyncGenerator, Dict
from typing import (
Any,
AsyncGenerator,
Dict,
Protocol,
runtime_checkable,
)

from prisma import Prisma # type: ignore[attr-defined]
from prisma.actions import AuthTokenActions, ModelActions


class BaseProtocol:
async def get_db_url(self) -> str:
raise NotImplementedError()
@runtime_checkable
class BaseProtocol(Protocol):
@staticmethod
async def get_db_url() -> str: ...

@asynccontextmanager # type: ignore[arg-type]
async def get_db_connection(self) -> AsyncGenerator[Prisma, None]:
raise NotImplementedError()
async def get_db_connection(self) -> AsyncGenerator[Prisma, None]: ...


class BaseBackendProtocol(BaseProtocol):
async def find_model_using_raw(self, model_uuid: str) -> Dict[str, Any]:
raise NotImplementedError()
@runtime_checkable
class BaseBackendProtocol(BaseProtocol, Protocol):
async def find_model_using_raw(self, model_uuid: str) -> Dict[str, Any]: ...

@asynccontextmanager # type: ignore[arg-type]
async def get_model_connection(
self,
) -> AsyncGenerator[ModelActions[Any], None]:
raise NotImplementedError()
) -> AsyncGenerator[ModelActions[Any], None]: ...

@asynccontextmanager # type: ignore[arg-type]
async def get_authtoken_connection(
self,
) -> AsyncGenerator[AuthTokenActions[Any], None]:
raise NotImplementedError()
) -> AsyncGenerator[AuthTokenActions[Any], None]: ...


class BaseFrontendProtocol(BaseProtocol):
async def get_user(self, user_uuid: str) -> Dict[str, Any]:
raise NotImplementedError()
@runtime_checkable
class BaseFrontendProtocol(BaseProtocol, Protocol):
async def get_user(self, user_uuid: str) -> Dict[str, Any]: ...
15 changes: 10 additions & 5 deletions fastagency/db/prisma.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,15 @@
from prisma import Prisma # type: ignore[attr-defined]
from prisma.actions import AuthTokenActions, ModelActions

from .base import BaseBackendProtocol, BaseFrontendProtocol
from .base import (
BaseBackendProtocol,
BaseFrontendProtocol,
)


class BackendDBProtocol(BaseBackendProtocol):
async def get_db_url(self) -> str:
@staticmethod
async def get_db_url() -> str:
db_url: str = environ.get("PY_DATABASE_URL", None) # type: ignore[assignment,arg-type]
if not db_url:
raise ValueError(
Expand All @@ -25,7 +29,7 @@ async def get_db_url(self) -> str:
async def get_db_connection( # type: ignore[override]
self,
) -> AsyncGenerator[Prisma, None]:
db_url = await self.get_db_url()
db_url = await BackendDBProtocol.get_db_url()
db = Prisma(datasource={"url": db_url})
await db.connect()
try:
Expand Down Expand Up @@ -65,7 +69,8 @@ async def get_authtoken_connection( # type: ignore[override]


class FrontendDBProtocol(BaseFrontendProtocol):
async def get_db_url(self) -> str:
@staticmethod
async def get_db_url() -> str:
db_url: str = environ.get("DATABASE_URL", None) # type: ignore[assignment,arg-type]
if not db_url:
raise ValueError(
Expand All @@ -79,7 +84,7 @@ async def get_db_url(self) -> str:
async def get_db_connection( # type: ignore[override]
self,
) -> AsyncGenerator[Prisma, None]:
db_url = await self.get_db_url()
db_url = await FrontendDBProtocol.get_db_url()
db = Prisma(datasource={"url": db_url})
await db.connect()
try:
Expand Down

0 comments on commit ac3e520

Please sign in to comment.