From ac3e5202c94308442752a86a0c4b0ced662e0c56 Mon Sep 17 00:00:00 2001 From: Kumaran Rajendhiran Date: Fri, 26 Jul 2024 13:33:26 +0000 Subject: [PATCH] Use protocol --- fastagency/db/base.py | 36 ++++++++++++++++++++---------------- fastagency/db/prisma.py | 15 ++++++++++----- 2 files changed, 30 insertions(+), 21 deletions(-) diff --git a/fastagency/db/base.py b/fastagency/db/base.py index d76ef178b..cb87f6b48 100644 --- a/fastagency/db/base.py +++ b/fastagency/db/base.py @@ -1,36 +1,40 @@ from contextlib import asynccontextmanager -from typing import Any, AsyncGenerator, Dict +from typing import ( + Any, + AsyncGenerator, + Dict, + Protocol, + runtime_checkable, +) from prisma import Prisma # type: ignore[attr-defined] from prisma.actions import AuthTokenActions, ModelActions -class BaseProtocol: - async def get_db_url(self) -> str: - raise NotImplementedError() +@runtime_checkable +class BaseProtocol(Protocol): + @staticmethod + async def get_db_url() -> str: ... @asynccontextmanager # type: ignore[arg-type] - async def get_db_connection(self) -> AsyncGenerator[Prisma, None]: - raise NotImplementedError() + async def get_db_connection(self) -> AsyncGenerator[Prisma, None]: ... -class BaseBackendProtocol(BaseProtocol): - async def find_model_using_raw(self, model_uuid: str) -> Dict[str, Any]: - raise NotImplementedError() +@runtime_checkable +class BaseBackendProtocol(BaseProtocol, Protocol): + async def find_model_using_raw(self, model_uuid: str) -> Dict[str, Any]: ... @asynccontextmanager # type: ignore[arg-type] async def get_model_connection( self, - ) -> AsyncGenerator[ModelActions[Any], None]: - raise NotImplementedError() + ) -> AsyncGenerator[ModelActions[Any], None]: ... @asynccontextmanager # type: ignore[arg-type] async def get_authtoken_connection( self, - ) -> AsyncGenerator[AuthTokenActions[Any], None]: - raise NotImplementedError() + ) -> AsyncGenerator[AuthTokenActions[Any], None]: ... -class BaseFrontendProtocol(BaseProtocol): - async def get_user(self, user_uuid: str) -> Dict[str, Any]: - raise NotImplementedError() +@runtime_checkable +class BaseFrontendProtocol(BaseProtocol, Protocol): + async def get_user(self, user_uuid: str) -> Dict[str, Any]: ... diff --git a/fastagency/db/prisma.py b/fastagency/db/prisma.py index a4e10bc45..26fbb4ec4 100644 --- a/fastagency/db/prisma.py +++ b/fastagency/db/prisma.py @@ -7,11 +7,15 @@ from prisma import Prisma # type: ignore[attr-defined] from prisma.actions import AuthTokenActions, ModelActions -from .base import BaseBackendProtocol, BaseFrontendProtocol +from .base import ( + BaseBackendProtocol, + BaseFrontendProtocol, +) class BackendDBProtocol(BaseBackendProtocol): - async def get_db_url(self) -> str: + @staticmethod + async def get_db_url() -> str: db_url: str = environ.get("PY_DATABASE_URL", None) # type: ignore[assignment,arg-type] if not db_url: raise ValueError( @@ -25,7 +29,7 @@ async def get_db_url(self) -> str: async def get_db_connection( # type: ignore[override] self, ) -> AsyncGenerator[Prisma, None]: - db_url = await self.get_db_url() + db_url = await BackendDBProtocol.get_db_url() db = Prisma(datasource={"url": db_url}) await db.connect() try: @@ -65,7 +69,8 @@ async def get_authtoken_connection( # type: ignore[override] class FrontendDBProtocol(BaseFrontendProtocol): - async def get_db_url(self) -> str: + @staticmethod + async def get_db_url() -> str: db_url: str = environ.get("DATABASE_URL", None) # type: ignore[assignment,arg-type] if not db_url: raise ValueError( @@ -79,7 +84,7 @@ async def get_db_url(self) -> str: async def get_db_connection( # type: ignore[override] self, ) -> AsyncGenerator[Prisma, None]: - db_url = await self.get_db_url() + db_url = await FrontendDBProtocol.get_db_url() db = Prisma(datasource={"url": db_url}) await db.connect() try: