From 7b8b6b0d9bb2b730c9e34be8084be1fca44e53f3 Mon Sep 17 00:00:00 2001 From: vikbhas Date: Wed, 6 Nov 2024 10:31:38 +0100 Subject: [PATCH] Feat: Introduce ASYNC DB as Plug and Play --- creyPY/fastapi/db/__init__.py | 1 + creyPY/fastapi/db/async_session.py | 25 ++++++++++ creyPY/fastapi/pagination.py | 77 ++++++++++++++++++++++++++++-- requirements.txt | 3 ++ 4 files changed, 103 insertions(+), 3 deletions(-) create mode 100644 creyPY/fastapi/db/async_session.py diff --git a/creyPY/fastapi/db/__init__.py b/creyPY/fastapi/db/__init__.py index d13ef11..395efc9 100644 --- a/creyPY/fastapi/db/__init__.py +++ b/creyPY/fastapi/db/__init__.py @@ -1 +1,2 @@ from .session import * # noqa +from .async_session import * # noqa diff --git a/creyPY/fastapi/db/async_session.py b/creyPY/fastapi/db/async_session.py new file mode 100644 index 0000000..6856914 --- /dev/null +++ b/creyPY/fastapi/db/async_session.py @@ -0,0 +1,25 @@ +import os +from typing import AsyncGenerator +from dotenv import load_dotenv +from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine +from sqlalchemy.orm import sessionmaker + + +load_dotenv() + +host = os.getenv("POSTGRES_HOST", "localhost") +user = os.getenv("POSTGRES_USER", "postgres") +password = os.getenv("POSTGRES_PASSWORD", "root") +port = os.getenv("POSTGRES_PORT", "5432") +name = os.getenv("POSTGRES_DB", "fastapi") + +SQLALCHEMY_DATABASE_URL = f"postgresql+psycopg://{user}:{password}@{host}:{port}/" + + +async_engine = create_async_engine(SQLALCHEMY_DATABASE_URL + name, pool_pre_ping=True) +AsyncSessionLocal = sessionmaker(bind=async_engine, class_=AsyncSession, + expire_on_commit=False, autoflush=False, autocommit=False) + +async def get_async_db() -> AsyncGenerator[AsyncSession, None]: + async with AsyncSessionLocal() as db: + yield db diff --git a/creyPY/fastapi/pagination.py b/creyPY/fastapi/pagination.py index f96f698..d747c2d 100644 --- a/creyPY/fastapi/pagination.py +++ b/creyPY/fastapi/pagination.py @@ -1,5 +1,6 @@ from math import ceil -from typing import Any, Generic, Optional, Self, Sequence, TypeVar, Union +from typing import Any, Generic, Optional, Self, Sequence, TypeVar, Union, overload +from contextlib import suppress from pydantic import BaseModel from fastapi_pagination import Params from fastapi_pagination.bases import AbstractPage, AbstractParams @@ -8,6 +9,8 @@ GreaterEqualZero, AdditionalData, SyncItemsTransformer, + AsyncItemsTransformer, + ItemsTransformer ) from fastapi_pagination.api import create_page, apply_items_transformer from fastapi_pagination.utils import verify_params @@ -17,7 +20,9 @@ from sqlalchemy.sql.selectable import Select from sqlalchemy.orm.session import Session from sqlalchemy import select, func +from sqlalchemy.ext.asyncio import AsyncSession, async_scoped_session from fastapi import Query +from sqlalchemy.util import await_only, greenlet_spawn T = TypeVar("T") @@ -107,19 +112,59 @@ def unwrap_scalars( ) -> Union[Sequence[T], Sequence[Sequence[T]]]: return [item[0] if force_unwrap else item for item in items] +def _get_sync_conn_from_async(conn: Any) -> Session: # pragma: no cover + if isinstance(conn, async_scoped_session): + conn = conn() + with suppress(AttributeError): + return conn.sync_session # type: ignore + + with suppress(AttributeError): + return conn.sync_connection # type: ignore + + raise TypeError("conn must be an AsyncConnection or AsyncSession") + +@overload def paginate( connection: Session, query: Select, params: Optional[AbstractParams] = None, transformer: Optional[SyncItemsTransformer] = None, additional_data: Optional[AdditionalData] = None, +) -> Any: + pass + + +@overload +async def paginate( + connection: AsyncSession, + query: Select, + params: Optional[AbstractParams] = None, + transformer: Optional[AsyncItemsTransformer] = None, + additional_data: Optional[AdditionalData] = None, +) -> Any: + pass + + +def _paginate( + connection: Session, + query: Select, + params: Optional[AbstractParams] = None, + transformer: Optional[ItemsTransformer] = None, + additional_data: Optional[AdditionalData] = None, + async_:bool = False ): + + if async_: + def _apply_items_transformer(*args: Any, **kwargs: Any) -> Any: + return await_only(apply_items_transformer(*args, **kwargs, async_=True)) + else: + _apply_items_transformer = apply_items_transformer params, raw_params = verify_params(params, "limit-offset", "cursor") count_query = create_count_query(query) total = connection.scalar(count_query) - + if params.pagination is False and total > 0: params = Params(page=1, size=total) else: @@ -129,7 +174,7 @@ def paginate( items = connection.execute(query).all() items = unwrap_scalars(items) - t_items = apply_items_transformer(items, transformer) + t_items = _apply_items_transformer(items, transformer) return create_page( t_items, @@ -137,3 +182,29 @@ def paginate( total=total, **(additional_data or {}), ) + +def paginate( + connection: Session, + query: Select, + params: Optional[AbstractParams] = None, + transformer: Optional[ItemsTransformer] = None, + additional_data: Optional[AdditionalData] = None, +): + if isinstance(connection,AsyncSession): + connection = _get_sync_conn_from_async(connection) + return greenlet_spawn(_paginate, + connection, + query, + params, + transformer, + additional_data, + async_=True) + + return _paginate( + connection, + query, + params, + transformer, + additional_data, + async_=False + ) diff --git a/requirements.txt b/requirements.txt index f8f5ec2..a03bdfe 100644 --- a/requirements.txt +++ b/requirements.txt @@ -21,3 +21,6 @@ psycopg-pool>=3.2.2 # PostgreSQL h11>=0.14.0 # Testing httpcore>=1.0.5 # Testing httpx>=0.27.0 # Testing + +asyncpg>=0.30.0 #SQLAlchemy +greenlet>=3.1.1 #Async