diff --git a/app/api/deps.py b/app/api/deps.py index a7a2cd6..dbc9300 100644 --- a/app/api/deps.py +++ b/app/api/deps.py @@ -2,7 +2,7 @@ Dependencies """ -from typing import Annotated +from typing import Annotated, AsyncGenerator from fastapi import Depends from fastapi.security import APIKeyHeader @@ -11,8 +11,8 @@ from app.models.users.user import User as _User -async def get_logic() -> _Logic: - async with Logic.create() as logic: +async def get_logic() -> AsyncGenerator[_Logic, None]: + async with _Logic.create() as logic: yield logic @@ -20,7 +20,7 @@ async def get_logic() -> _Logic: async def get_user( - token: Annotated[str, Depends(APIKeyHeader(name='access-token'))], + token: Annotated[str, Depends(APIKeyHeader(name="access-token"))], logic: Logic, ) -> _User | None: return await logic.users.retrieve_by_token(token) diff --git a/app/api/v1/auth/token.py b/app/api/v1/auth/token.py index e7074d2..6ea9ed8 100644 --- a/app/api/v1/auth/token.py +++ b/app/api/v1/auth/token.py @@ -4,15 +4,15 @@ from app.models.auth import AccessToken from app.models.users.user import UserCreate -router = APIRouter(prefix='/token') +router = APIRouter(prefix="/token") -@router.post('/', response_model=AccessToken) +@router.post("/", response_model=AccessToken) async def token(data: UserCreate, logic: deps.Logic): """ Retrieve new access token """ - return await logic.auth.generate_token(**data.model_dump()) + return await logic.auth.generate_token(data) -__all__ = ['router'] +__all__ = ["router"] diff --git a/app/core/db.py b/app/core/db.py index ffbbffd..7ba6cf8 100644 --- a/app/core/db.py +++ b/app/core/db.py @@ -4,8 +4,7 @@ from typing import Self -from sqlalchemy.ext.asyncio import (AsyncEngine, async_sessionmaker, - create_async_engine) +from sqlalchemy.ext.asyncio import AsyncEngine, async_sessionmaker, create_async_engine from sqlmodel.ext.asyncio.session import AsyncSession from app import repositories as repos @@ -13,22 +12,13 @@ class Database: - _instance = None - - def __new__(cls, *args, **kwargs) -> 'Database': - if cls._instance is None: - cls._instance = super(Database, cls).__new__(cls) - return cls._instance - def __init__( self, engine: AsyncEngine | None = None, session: AsyncSession | None = None, ) -> None: - if not hasattr(self, 'initialized'): - self.__engine = engine - self.__session = session - self.initialized = True + self.__engine = engine + self.__session = session async def __set_async_engine(self) -> None: if self.__engine is None: @@ -60,4 +50,3 @@ async def __aexit__(self, exc_type, exc_value, traceback) -> None: if self.__session is not None: await self.__session.commit() await self.__session.close() - self.__session = None diff --git a/app/logic/auth/auth.py b/app/logic/auth/auth.py index aa45007..6027a51 100644 --- a/app/logic/auth/auth.py +++ b/app/logic/auth/auth.py @@ -2,23 +2,20 @@ from app.core import exps from app.models.auth import AccessToken +from app.models.users.user import UserCreate if TYPE_CHECKING: from app.logic import Logic class Auth: - def __init__(self, logic: 'Logic'): + def __init__(self, logic: "Logic"): self.logic = logic - async def generate_token( - self, email: str, password: str - ) -> AccessToken | None: - if (user := await self.logic.db.user.retrieve_by_email(email)) is None: + async def generate_token(self, data: UserCreate) -> AccessToken | None: + if (user := await self.logic.db.user.retrieve_by_email(data.email)) is None: raise exps.UserNotFoundException() - if not self.logic.security.pwd.checkpwd(password, user.password): + if not self.logic.security.pwd.checkpwd(data.password, user.password): raise exps.UserIsCorrectException() - access_token = self.logic.security.jwt.encode_token( - {'id': user.id}, 1440 - ) + access_token = self.logic.security.jwt.encode_token({"id": user.id}, 1440) return AccessToken(token=access_token) diff --git a/app/logic/logic.py b/app/logic/logic.py index 5287728..29c3eab 100644 --- a/app/logic/logic.py +++ b/app/logic/logic.py @@ -1,5 +1,5 @@ from contextlib import asynccontextmanager -from typing import AsyncGenerator, Self +from typing import AsyncGenerator from app.core.db import Database @@ -11,12 +11,12 @@ class Logic: def __init__(self, db: Database): self.db = db - self.security = Security() - self.users = Users(self) self.auth = Auth(self) + self.users = Users(self) + self.security = Security() @classmethod @asynccontextmanager - async def create(cls) -> AsyncGenerator[Self, None]: + async def create(cls) -> AsyncGenerator["Logic", None]: async with Database() as db: yield cls(db) diff --git a/app/logic/users/users.py b/app/logic/users/users.py index 16073dc..95479f3 100644 --- a/app/logic/users/users.py +++ b/app/logic/users/users.py @@ -8,23 +8,19 @@ class Users: - def __init__(self, logic: 'Logic'): + def __init__(self, logic: "Logic"): self.logic = logic - async def create(self, model: UserCreate) -> User | None: - if await self.logic.db.user.retrieve_by_email(model.email): + async def create(self, data: UserCreate) -> User | None: + if await self.logic.db.user.retrieve_by_email(data.email): raise exps.UserExistsException() - model.password = self.logic.security.pwd.hashpwd(model.password) - user = await self.logic.db.user.create(model) + data.password = self.logic.security.pwd.hashpwd(data.password) + user = await self.logic.db.user.create(data) return user async def retrieve_by_token(self, token: str) -> User | None: payload = self.logic.security.jwt.decode_token(token) - if not ( - user := await self.logic.db.user.retrieve_one( - ident=payload.get('id') - ) - ): + if not (user := await self.logic.db.user.retrieve_one(ident=payload.get("id"))): raise exps.UserNotFoundException() return user