From ff8a0684a099a5c93df60b9fdc34b4f9b1e8786c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tao=20Bojl=C3=A9n?= <66130243+taobojlen@users.noreply.github.com> Date: Fri, 19 Jul 2024 00:09:23 +0100 Subject: [PATCH] fix: ensure context is reset after leaving --- requirements-dev.txt | 15 +------ src/zealot/listeners.py | 91 +++++++++++++++++++++++------------------ tests/test_listeners.py | 53 +++++++++++++++++++++--- tests/test_nplusones.py | 10 +++++ 4 files changed, 110 insertions(+), 59 deletions(-) diff --git a/requirements-dev.txt b/requirements-dev.txt index b5d4b45..e7be097 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -4,8 +4,6 @@ asgiref==3.8.1 # via # django # django-stubs -backports-tarfile==1.2.0 - # via jaraco-context build==1.2.1 # via -r requirements-dev.in certifi==2024.6.2 @@ -23,8 +21,6 @@ django-stubs-ext==5.0.2 # via django-stubs docutils==0.21.2 # via readme-renderer -exceptiongroup==1.2.1 - # via pytest factory-boy==3.3.0 # via -r requirements-dev.in faker==26.0.0 @@ -32,10 +28,7 @@ faker==26.0.0 idna==3.7 # via requests importlib-metadata==8.0.0 - # via - # build - # keyring - # twine + # via twine iniconfig==2.0.0 # via pytest jaraco-classes==3.4.0 @@ -103,18 +96,12 @@ six==1.16.0 # via python-dateutil sqlparse==0.5.0 # via django -tomli==2.0.1 - # via - # build - # django-stubs - # pytest twine==5.1.1 # via -r requirements-dev.in types-pyyaml==6.0.12.20240311 # via django-stubs typing-extensions==4.12.2 # via - # asgiref # django-stubs # django-stubs-ext urllib3==2.2.2 diff --git a/src/zealot/listeners.py b/src/zealot/listeners.py index 353b6e4..304ae81 100644 --- a/src/zealot/listeners.py +++ b/src/zealot/listeners.py @@ -3,8 +3,9 @@ from collections import defaultdict from contextlib import contextmanager from contextvars import ContextVar, Token +from dataclasses import dataclass, field from fnmatch import fnmatch -from typing import Optional, TypedDict, Union +from typing import Optional, TypedDict from django.conf import settings from django.db import models @@ -20,10 +21,24 @@ class QuerySource(TypedDict): instance_key: Optional[str] # e.g. `User:123` -# None means not initialized -# bool means initialized, in/not in zealot context -_is_in_context: ContextVar[Union[None, bool]] = ContextVar( - "in_context", default=None +# tuple of (model, field, caller) +CountsKey = tuple[type[models.Model], str, str] + + +@dataclass +class NPlusOneContext: + # None means not initialized + # bool means initialized, in/not in zealot context + is_in_context: Optional[bool] = None + counts: dict[CountsKey, int] = field( + default_factory=lambda: defaultdict(int) + ) + ignored: set[str] = field(default_factory=set) + + +_nplusone_context: ContextVar[NPlusOneContext] = ContextVar( + "nplusone", + default=NPlusOneContext(), ) logger = logging.getLogger("zealot") @@ -38,9 +53,6 @@ class Listener(ABC): @abstractmethod def notify(self, *args, **kwargs): ... - @abstractmethod - def reset(self): ... - @property @abstractmethod def error_class(self) -> type[ZealotError]: ... @@ -80,12 +92,6 @@ def _alert(self, model: type[models.Model], field: str, message: str): class NPlusOneListener(Listener): - ignored_instances: set[str] - counts: dict[tuple[type[models.Model], str, str], int] - - def __init__(self): - self.reset() - @property def error_class(self): return NPlusOneError @@ -96,19 +102,18 @@ def notify( field: str, instance_key: Optional[str], ): - if not _is_in_context.get(): + context = _nplusone_context.get() + if not context.is_in_context: return caller = get_caller() key = (model, field, f"{caller.filename}:{caller.lineno}") - self.counts[key] += 1 - count = self.counts[key] - if ( - count >= self._threshold - and instance_key not in self.ignored_instances - ): + context.counts[key] += 1 + count = context.counts[key] + if count >= self._threshold and instance_key not in context.ignored: message = f"N+1 detected on {model.__name__}.{field}" self._alert(model, field, message) + _nplusone_context.set(context) def ignore(self, instance_key: Optional[str]): """ @@ -117,15 +122,13 @@ def ignore(self, instance_key: Optional[str]): This is used when the given instance is singly-loaded, e.g. via `.first()` or `.get()`. This is to prevent false positives. """ - if not _is_in_context.get(): + context = _nplusone_context.get() + if not context.is_in_context: return if not instance_key: return - self.ignored_instances.add(instance_key) - - def reset(self): - self.counts = defaultdict(int) - self.ignored_instances = set() + context.ignored.add(instance_key) + _nplusone_context.set(context) @property def _threshold(self) -> int: @@ -138,21 +141,25 @@ def _threshold(self) -> int: n_plus_one_listener = NPlusOneListener() -def setup() -> Token: - new_context_value = True - if _is_in_context.get() is False: - # if we're already in an ignore-context, we don't want to override - # it. - new_context_value = False - return _is_in_context.set(new_context_value) +def setup() -> Optional[Token]: + # if we're already in an ignore-context, we don't want to override + # it. + context = _nplusone_context.get() + if context.is_in_context is False: + new_is_in_context = False + else: + new_is_in_context = True + + return _nplusone_context.set( + NPlusOneContext(is_in_context=new_is_in_context) + ) def teardown(token: Optional[Token] = None): - n_plus_one_listener.reset() if token: - _is_in_context.reset(token) + _nplusone_context.reset(token) else: - _is_in_context.set(False) + _nplusone_context.set(NPlusOneContext()) @contextmanager @@ -166,8 +173,14 @@ def zealot_context(): @contextmanager def zealot_ignore(): - token = _is_in_context.set(False) + old_context = _nplusone_context.get() + new_context = NPlusOneContext( + counts=old_context.counts.copy(), + ignored=old_context.ignored.copy(), + is_in_context=False, + ) + token = _nplusone_context.set(new_context) try: yield finally: - _is_in_context.reset(token) + _nplusone_context.reset(token) diff --git a/tests/test_listeners.py b/tests/test_listeners.py index d26343d..924a6d3 100644 --- a/tests/test_listeners.py +++ b/tests/test_listeners.py @@ -3,8 +3,8 @@ import pytest from djangoproject.social.models import Post, User -from zealot import zealot_context, zealot_ignore -from zealot.errors import NPlusOneError +from zealot import NPlusOneError, zealot_context, zealot_ignore +from zealot.listeners import _nplusone_context, n_plus_one_listener from .factories import PostFactory, UserFactory @@ -105,11 +105,52 @@ def test_ignore_context_takes_precedence(): _ = list(user.posts.all()) -def test_ignores_calls_on_different_lines(): +def test_reverts_to_previous_state_when_leaving_zealot_ignore(): + # we are currently in a zealot context + assert _nplusone_context.get().is_in_context is True + with zealot_ignore(): + assert _nplusone_context.get().is_in_context is False + assert _nplusone_context.get().is_in_context is True + + # if we start off *without* being in a context, that also gets reset + context = _nplusone_context.get() + context.is_in_context = None + _nplusone_context.set(context) + + assert _nplusone_context.get().is_in_context is None + with zealot_ignore(): + assert _nplusone_context.get().is_in_context is False + assert _nplusone_context.get().is_in_context is None + + +def test_resets_state_in_nested_context(): [user_1, user_2] = UserFactory.create_batch(2) PostFactory.create(author=user_1) PostFactory.create(author=user_2) - # this should *not* raise an exception - _a = list(user_1.posts.all()) - _b = list(user_2.posts.all()) + # we're already in a zealot_context within each test, so let's set + # some state. + n_plus_one_listener.ignore("Test:1") + n_plus_one_listener.notify(Post, "test_field", "Post:1") + + context = _nplusone_context.get() + assert context.ignored == {"Test:1"} + assert list(context.counts.values()) == [1] + + with zealot_context(): + # new context, fresh state + context = _nplusone_context.get() + assert context.ignored == set() + assert list(context.counts.values()) == [] + + n_plus_one_listener.ignore("NestedTest:1") + n_plus_one_listener.notify(Post, "nested_test_field", "Post:1") + + context = _nplusone_context.get() + assert context.ignored == {"NestedTest:1"} + assert list(context.counts.values()) == [1] + + # back outside the nested context, we're back to the old state + context = _nplusone_context.get() + assert context.ignored == {"Test:1"} + assert list(context.counts.values()) == [1] diff --git a/tests/test_nplusones.py b/tests/test_nplusones.py index f27ec7f..99a31e9 100644 --- a/tests/test_nplusones.py +++ b/tests/test_nplusones.py @@ -475,3 +475,13 @@ def test_works_in_web_requests(client): assert response.status_code == 200 response = client.get(f"/user/{user_2.pk}/") assert response.status_code == 200 + + +def test_ignores_calls_on_different_lines(): + [user_1, user_2] = UserFactory.create_batch(2) + PostFactory.create(author=user_1) + PostFactory.create(author=user_2) + + # this should *not* raise an exception + _a = list(user_1.posts.all()) + _b = list(user_2.posts.all())