From 76d22d379dfa56c7066a9e2b2b8efc01fe2e4d71 Mon Sep 17 00:00:00 2001 From: Julien Maupetit Date: Thu, 12 Dec 2024 11:19:42 +0100 Subject: [PATCH] =?UTF-8?q?=E2=9A=A1=EF=B8=8F(api)=20prefetch=20user-relat?= =?UTF-8?q?ed=20groups=20and=20operational=20units?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Decrease the number of database queries by prefetching groups and related operational units linked to the request user. --- src/api/CHANGELOG.md | 4 +++ src/api/qualicharge/auth/oidc.py | 13 ++++++-- src/api/qualicharge/db.py | 28 +++++++++++++++- src/api/tests/auth/test_oidc.py | 55 ++++++++++++++++++++++++++++++-- 4 files changed, 95 insertions(+), 5 deletions(-) diff --git a/src/api/CHANGELOG.md b/src/api/CHANGELOG.md index 5dc71329..a743d5a0 100644 --- a/src/api/CHANGELOG.md +++ b/src/api/CHANGELOG.md @@ -8,6 +8,10 @@ and this project adheres to ## [Unreleased] +### Changed + +- Prefetch user-related groups and operational units in `get_user` dependency + ## [0.16.0] - 2024-12-12 ### Changed diff --git a/src/api/qualicharge/auth/oidc.py b/src/api/qualicharge/auth/oidc.py index 7926c4b2..55bf5209 100644 --- a/src/api/qualicharge/auth/oidc.py +++ b/src/api/qualicharge/auth/oidc.py @@ -19,6 +19,7 @@ InvalidTokenError, ) from pydantic import AnyHttpUrl +from sqlalchemy.orm import joinedload from sqlmodel import Session as SMSession from sqlmodel import select @@ -31,7 +32,7 @@ PermissionDenied, ) from .models import IDToken -from .schemas import User +from .schemas import Group, User # API auth logger logger = logging.getLogger(__name__) @@ -157,7 +158,15 @@ def get_user( ) -> User: """Get request user.""" # Get registered user - user = session.exec(select(User).where(User.email == token.email)).one_or_none() + user = ( + session.exec( + select(User) + .options(joinedload(User.groups).joinedload(Group.operational_units)) # type: ignore[arg-type] + .where(User.email == token.email) + ) + .unique() + .one_or_none() + ) # User does not exist: raise an error if user is None: diff --git a/src/api/qualicharge/db.py b/src/api/qualicharge/db.py index 4786fcf8..3c234981 100644 --- a/src/api/qualicharge/db.py +++ b/src/api/qualicharge/db.py @@ -5,7 +5,7 @@ from pydantic import PostgresDsn from sqlalchemy import Engine as SAEngine -from sqlalchemy import text +from sqlalchemy import event, text from sqlalchemy.exc import OperationalError from sqlmodel import Session as SMSession from sqlmodel import create_engine @@ -49,6 +49,32 @@ def get_engine( return self._engine +class SAQueryCounter: + """Context manager to count SQLALchemy queries. + + Inspired by: https://stackoverflow.com/a/71337784 + """ + + def __init__(self, connection): + """Initialize the counter for a given connection.""" + self.connection = connection.engine + self.count = 0 + + def __enter__(self): + """Start listening `before_cursor_execute` event.""" + event.listen(self.connection, "before_cursor_execute", self.callback) + return self + + def __exit__(self, *args, **kwargs): + """Stop listening `before_cursor_execute` event.""" + event.remove(self.connection, "before_cursor_execute", self.callback) + + def callback(self, *args, **kwargs): + """Increment the counter every time the `before_cursor_execute` event occurs.""" + self.count += 1 + logger.debug(f"Database query [{self.count=}] >> {args=} {kwargs=}") + + def get_engine() -> SAEngine: """Get database engine.""" return Engine().get_engine( diff --git a/src/api/tests/auth/test_oidc.py b/src/api/tests/auth/test_oidc.py index 5e751119..8e27830f 100644 --- a/src/api/tests/auth/test_oidc.py +++ b/src/api/tests/auth/test_oidc.py @@ -7,21 +7,24 @@ import jwt import pytest from fastapi.security import HTTPAuthorizationCredentials, SecurityScopes +from sqlmodel import select -from qualicharge.auth.factories import IDTokenFactory, UserFactory +from qualicharge.auth.factories import GroupFactory, IDTokenFactory, UserFactory from qualicharge.auth.oidc import ( discover_provider, get_public_keys, get_token, get_user, ) -from qualicharge.auth.schemas import ScopesEnum +from qualicharge.auth.schemas import GroupOperationalUnit, ScopesEnum from qualicharge.conf import settings +from qualicharge.db import SAQueryCounter from qualicharge.exceptions import ( AuthenticationError, OIDCProviderException, PermissionDenied, ) +from qualicharge.schemas.core import OperationalUnit def setup_function(): @@ -298,3 +301,51 @@ def test_get_user_for_user_with_limited_scopes( token=id_token_factory.build(), session=db_session, ) + + +def test_get_user_number_of_queries(id_token_factory: IDTokenFactory, db_session): + """Test the OIDC get user utility number of queries for a standard user.""" + UserFactory.__session__ = db_session + GroupFactory.__session__ = db_session + + token = id_token_factory.build() + + # Create groups linked to Operational Units + groups = GroupFactory.create_batch_sync(3) + operational_units = db_session.exec(select(OperationalUnit).limit(3)).all() + for group, operational_unit in zip(groups, operational_units, strict=True): + db_session.add( + GroupOperationalUnit( + group_id=group.id, operational_unit_id=operational_unit.id + ) + ) + + # Create user linked to this groups and related operational units + user = UserFactory.create_sync( + email=token.email, + is_superuser=False, + is_active=True, + groups=groups, + scopes=[ScopesEnum.ALL_CREATE], + ) + + # Test the number of queries + with SAQueryCounter(db_session.connection()) as counter: + user = get_user( + security_scopes=SecurityScopes(scopes=[ScopesEnum.ALL_CREATE]), + token=token, + session=db_session, + ) + assert counter.count == 1 + + # When getting groups... + with SAQueryCounter(db_session.connection()) as counter: + assert {g.id for g in user.groups} == {g.id for g in groups} + assert counter.count == 0 + + # ... and related operational units + with SAQueryCounter(db_session.connection()) as counter: + assert {ou.id for g in user.groups for ou in g.operational_units} == { + ou.id for ou in operational_units + } + assert counter.count == 0