diff --git a/src/selva/ext/data/sqlalchemy/__init__.py b/src/selva/ext/data/sqlalchemy/__init__.py index ecd2a89..92606c5 100644 --- a/src/selva/ext/data/sqlalchemy/__init__.py +++ b/src/selva/ext/data/sqlalchemy/__init__.py @@ -4,7 +4,8 @@ from selva.di.container import Container from selva.di.decorator import service as service_decorator -from .service import make_engine_service, make_sessionmaker_service # noqa: F401 +from .service import make_engine_service # noqa: F401 +from .service import engine_dict_service, sessionmaker_service def selva_extension(container: Container, settings: Settings): @@ -17,8 +18,6 @@ def selva_extension(container: Container, settings: Settings): container.register( service_decorator(make_engine_service(name), name=service_name) ) - container.register( - service_decorator( - make_sessionmaker_service(service_name), name=service_name - ) - ) + + container.register(service_decorator(engine_dict_service)) + container.register(service_decorator(sessionmaker_service)) diff --git a/src/selva/ext/data/sqlalchemy/service.py b/src/selva/ext/data/sqlalchemy/service.py index e46aafc..58b83e9 100644 --- a/src/selva/ext/data/sqlalchemy/service.py +++ b/src/selva/ext/data/sqlalchemy/service.py @@ -3,7 +3,7 @@ from sqlalchemy.ext.asyncio import AsyncEngine, async_sessionmaker, create_async_engine from selva.configuration.settings import Settings -from selva.di import Inject +from selva.di import Container, Inject from selva.ext.data.sqlalchemy.settings import SqlAlchemySettings @@ -26,10 +26,16 @@ async def engine_service( return engine_service -def make_sessionmaker_service(name: str): - async def sessionmaker_service( - engine: Annotated[AsyncEngine, Inject(name=name)], - ) -> async_sessionmaker: - return async_sessionmaker(bind=engine) +async def engine_dict_service( + di: Container, settings: Settings +) -> dict[str, AsyncEngine]: + return { + db: await di.get(AsyncEngine, name=db if db != "default" else None) + for db in settings.data.sqlalchemy + } - return sessionmaker_service + +async def sessionmaker_service(engines: dict[str, AsyncEngine]) -> async_sessionmaker: + default = engines.pop("default", None) + + return async_sessionmaker(bind=default, binds=engines, expire_on_commit=False) diff --git a/tests/ext/data/sqlalchemy/application_named.py b/tests/ext/data/sqlalchemy/application_named.py index 8add2cb..32cf315 100644 --- a/tests/ext/data/sqlalchemy/application_named.py +++ b/tests/ext/data/sqlalchemy/application_named.py @@ -2,7 +2,7 @@ from asgikit.responses import respond_text from sqlalchemy import text -from sqlalchemy.ext.asyncio import async_sessionmaker +from sqlalchemy.ext.asyncio import AsyncEngine from selva.di import Inject from selva.web import controller, get @@ -10,12 +10,12 @@ @controller class Controller: - sessionmaker: Annotated[async_sessionmaker, Inject(name="other")] + engine: Annotated[AsyncEngine, Inject(name="other")] @get async def index(self, request): - async with self.sessionmaker() as session: - result = await session.execute(text("select sqlite_version()")) + async with self.engine.begin() as conn: + result = await conn.execute(text("select sqlite_version()")) version = result.first()[0] await respond_text(request.response, version) diff --git a/tests/ext/data/sqlalchemy/test_application.py b/tests/ext/data/sqlalchemy/test_application.py index 8c82320..9629ca4 100644 --- a/tests/ext/data/sqlalchemy/test_application.py +++ b/tests/ext/data/sqlalchemy/test_application.py @@ -33,4 +33,4 @@ async def test_application(application: str, database: str): client = AsyncClient(app=app) response = await client.get("http://localhost:8000/") - assert response.status_code == HTTPStatus.OK + assert response.status_code == HTTPStatus.OK, response.text diff --git a/tests/ext/data/sqlalchemy/test_service.py b/tests/ext/data/sqlalchemy/test_service.py index bf0433c..5bda8bf 100644 --- a/tests/ext/data/sqlalchemy/test_service.py +++ b/tests/ext/data/sqlalchemy/test_service.py @@ -1,10 +1,13 @@ from sqlalchemy import text +from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine from selva.configuration.defaults import default_settings from selva.configuration.settings import Settings +from selva.di.container import Container from selva.ext.data.sqlalchemy.service import ( + engine_dict_service, make_engine_service, - make_sessionmaker_service, + sessionmaker_service, ) @@ -96,17 +99,13 @@ async def test_make_engine_service_with_execution_options(): } ) - engine_service = make_engine_service("default") - async for engine in engine_service(settings): - sessionmaker_service = make_sessionmaker_service("default") - sessionmaker = await sessionmaker_service(engine) + async for engine in make_engine_service("default")(settings): + async with engine.connect() as conn: + result = await conn.execute(text("select 1")) + isolation_level = result.context.execution_options["isolation_level"] + assert isolation_level == "READ UNCOMMITTED" - async with sessionmaker() as session: - result = await session.execute(text("select 1")) - assert ( - result.context.execution_options["isolation_level"] - == "READ UNCOMMITTED" - ) + await engine.dispose() async def test_make_engine_service_alternative_name(): @@ -126,3 +125,47 @@ async def test_make_engine_service_alternative_name(): engine_service = make_engine_service("other")(settings) async for engine in engine_service: assert engine is not None + await engine.dispose() + + +async def test_sessionmaker_service(): + engine = create_async_engine("sqlite+aiosqlite:///:memory:") + + sessionmaker = await sessionmaker_service({"default": engine}) + async with sessionmaker() as session: + result = await session.execute(text("select 1")) + assert result.scalar() == 1 + + await engine.dispose() + + +async def test_engine_dict_service(): + ioc = Container() + + settings = Settings( + default_settings + | { + "data": { + "sqlalchemy": { + "default": { + "url": "sqlite+aiosqlite:///:memory:", + }, + "other": { + "url": "sqlite+aiosqlite:///:memory:", + }, + }, + }, + } + ) + + ioc.define(Settings, settings) + ioc.define(AsyncEngine, create_async_engine(settings.data.sqlalchemy.default.url)) + ioc.define( + AsyncEngine, + create_async_engine(settings.data.sqlalchemy.other.url), + name="other", + ) + + engines = await engine_dict_service(ioc, settings) + + assert set(engines.keys()) == {"default", "other"} diff --git a/tests/ext/data/sqlalchemy/test_service_postgres.py b/tests/ext/data/sqlalchemy/test_service_postgres.py index d1eece8..1bbf344 100644 --- a/tests/ext/data/sqlalchemy/test_service_postgres.py +++ b/tests/ext/data/sqlalchemy/test_service_postgres.py @@ -3,13 +3,11 @@ import pytest from sqlalchemy import make_url, text +from sqlalchemy.ext.asyncio import create_async_engine from selva.configuration.defaults import default_settings from selva.configuration.settings import Settings -from selva.ext.data.sqlalchemy.service import ( - make_engine_service, - make_sessionmaker_service, -) +from selva.ext.data.sqlalchemy.service import make_engine_service, sessionmaker_service from .test_service import _test_engine_service @@ -122,11 +120,19 @@ async def test_make_engine_service_with_execution_options(): } ) - engine_service = make_engine_service("default") - engine = await anext(engine_service(settings)) - sessionmaker_service = make_sessionmaker_service("default") - sessionmaker = await sessionmaker_service(engine) + async for engine in make_engine_service("default")(settings): + async with engine.connect() as conn: + result = await conn.execute(text("select 1")) + isolation_level = result.context.execution_options["isolation_level"] + assert isolation_level == "READ COMMITTED" + +async def test_sessionmaker_service(): + engine = create_async_engine(POSTGRES_URL) + + sessionmaker = await sessionmaker_service({"default": engine}) async with sessionmaker() as session: result = await session.execute(text("select 1")) - assert result.context.execution_options["isolation_level"] == "READ COMMITTED" + assert result.scalar() == 1 + + await engine.dispose()