Skip to content

Commit

Permalink
fix: ensure context is reset after leaving
Browse files Browse the repository at this point in the history
  • Loading branch information
taobojlen committed Jul 18, 2024
1 parent 4e6cc06 commit ff8a068
Show file tree
Hide file tree
Showing 4 changed files with 110 additions and 59 deletions.
15 changes: 1 addition & 14 deletions requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@ asgiref==3.8.1
# via
# django
# django-stubs
backports-tarfile==1.2.0
# via jaraco-context
build==1.2.1
# via -r requirements-dev.in
certifi==2024.6.2
Expand All @@ -23,19 +21,14 @@ django-stubs-ext==5.0.2
# via django-stubs
docutils==0.21.2
# via readme-renderer
exceptiongroup==1.2.1
# via pytest
factory-boy==3.3.0
# via -r requirements-dev.in
faker==26.0.0
# via factory-boy
idna==3.7
# via requests
importlib-metadata==8.0.0
# via
# build
# keyring
# twine
# via twine
iniconfig==2.0.0
# via pytest
jaraco-classes==3.4.0
Expand Down Expand Up @@ -103,18 +96,12 @@ six==1.16.0
# via python-dateutil
sqlparse==0.5.0
# via django
tomli==2.0.1
# via
# build
# django-stubs
# pytest
twine==5.1.1
# via -r requirements-dev.in
types-pyyaml==6.0.12.20240311
# via django-stubs
typing-extensions==4.12.2
# via
# asgiref
# django-stubs
# django-stubs-ext
urllib3==2.2.2
Expand Down
91 changes: 52 additions & 39 deletions src/zealot/listeners.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@
from collections import defaultdict
from contextlib import contextmanager
from contextvars import ContextVar, Token
from dataclasses import dataclass, field
from fnmatch import fnmatch
from typing import Optional, TypedDict, Union
from typing import Optional, TypedDict

from django.conf import settings
from django.db import models
Expand All @@ -20,10 +21,24 @@ class QuerySource(TypedDict):
instance_key: Optional[str] # e.g. `User:123`


# None means not initialized
# bool means initialized, in/not in zealot context
_is_in_context: ContextVar[Union[None, bool]] = ContextVar(
"in_context", default=None
# tuple of (model, field, caller)
CountsKey = tuple[type[models.Model], str, str]


@dataclass
class NPlusOneContext:
# None means not initialized
# bool means initialized, in/not in zealot context
is_in_context: Optional[bool] = None
counts: dict[CountsKey, int] = field(
default_factory=lambda: defaultdict(int)
)
ignored: set[str] = field(default_factory=set)


_nplusone_context: ContextVar[NPlusOneContext] = ContextVar(
"nplusone",
default=NPlusOneContext(),
)

logger = logging.getLogger("zealot")
Expand All @@ -38,9 +53,6 @@ class Listener(ABC):
@abstractmethod
def notify(self, *args, **kwargs): ...

@abstractmethod
def reset(self): ...

@property
@abstractmethod
def error_class(self) -> type[ZealotError]: ...
Expand Down Expand Up @@ -80,12 +92,6 @@ def _alert(self, model: type[models.Model], field: str, message: str):


class NPlusOneListener(Listener):
ignored_instances: set[str]
counts: dict[tuple[type[models.Model], str, str], int]

def __init__(self):
self.reset()

@property
def error_class(self):
return NPlusOneError
Expand All @@ -96,19 +102,18 @@ def notify(
field: str,
instance_key: Optional[str],
):
if not _is_in_context.get():
context = _nplusone_context.get()
if not context.is_in_context:
return

caller = get_caller()
key = (model, field, f"{caller.filename}:{caller.lineno}")
self.counts[key] += 1
count = self.counts[key]
if (
count >= self._threshold
and instance_key not in self.ignored_instances
):
context.counts[key] += 1
count = context.counts[key]
if count >= self._threshold and instance_key not in context.ignored:
message = f"N+1 detected on {model.__name__}.{field}"
self._alert(model, field, message)
_nplusone_context.set(context)

def ignore(self, instance_key: Optional[str]):
"""
Expand All @@ -117,15 +122,13 @@ def ignore(self, instance_key: Optional[str]):
This is used when the given instance is singly-loaded, e.g. via `.first()`
or `.get()`. This is to prevent false positives.
"""
if not _is_in_context.get():
context = _nplusone_context.get()
if not context.is_in_context:
return
if not instance_key:
return
self.ignored_instances.add(instance_key)

def reset(self):
self.counts = defaultdict(int)
self.ignored_instances = set()
context.ignored.add(instance_key)
_nplusone_context.set(context)

@property
def _threshold(self) -> int:
Expand All @@ -138,21 +141,25 @@ def _threshold(self) -> int:
n_plus_one_listener = NPlusOneListener()


def setup() -> Token:
new_context_value = True
if _is_in_context.get() is False:
# if we're already in an ignore-context, we don't want to override
# it.
new_context_value = False
return _is_in_context.set(new_context_value)
def setup() -> Optional[Token]:
# if we're already in an ignore-context, we don't want to override
# it.
context = _nplusone_context.get()
if context.is_in_context is False:
new_is_in_context = False
else:
new_is_in_context = True

return _nplusone_context.set(
NPlusOneContext(is_in_context=new_is_in_context)
)


def teardown(token: Optional[Token] = None):
n_plus_one_listener.reset()
if token:
_is_in_context.reset(token)
_nplusone_context.reset(token)
else:
_is_in_context.set(False)
_nplusone_context.set(NPlusOneContext())


@contextmanager
Expand All @@ -166,8 +173,14 @@ def zealot_context():

@contextmanager
def zealot_ignore():
token = _is_in_context.set(False)
old_context = _nplusone_context.get()
new_context = NPlusOneContext(
counts=old_context.counts.copy(),
ignored=old_context.ignored.copy(),
is_in_context=False,
)
token = _nplusone_context.set(new_context)
try:
yield
finally:
_is_in_context.reset(token)
_nplusone_context.reset(token)
53 changes: 47 additions & 6 deletions tests/test_listeners.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@

import pytest
from djangoproject.social.models import Post, User
from zealot import zealot_context, zealot_ignore
from zealot.errors import NPlusOneError
from zealot import NPlusOneError, zealot_context, zealot_ignore
from zealot.listeners import _nplusone_context, n_plus_one_listener

from .factories import PostFactory, UserFactory

Expand Down Expand Up @@ -105,11 +105,52 @@ def test_ignore_context_takes_precedence():
_ = list(user.posts.all())


def test_ignores_calls_on_different_lines():
def test_reverts_to_previous_state_when_leaving_zealot_ignore():
# we are currently in a zealot context
assert _nplusone_context.get().is_in_context is True
with zealot_ignore():
assert _nplusone_context.get().is_in_context is False
assert _nplusone_context.get().is_in_context is True

# if we start off *without* being in a context, that also gets reset
context = _nplusone_context.get()
context.is_in_context = None
_nplusone_context.set(context)

assert _nplusone_context.get().is_in_context is None
with zealot_ignore():
assert _nplusone_context.get().is_in_context is False
assert _nplusone_context.get().is_in_context is None


def test_resets_state_in_nested_context():
[user_1, user_2] = UserFactory.create_batch(2)
PostFactory.create(author=user_1)
PostFactory.create(author=user_2)

# this should *not* raise an exception
_a = list(user_1.posts.all())
_b = list(user_2.posts.all())
# we're already in a zealot_context within each test, so let's set
# some state.
n_plus_one_listener.ignore("Test:1")
n_plus_one_listener.notify(Post, "test_field", "Post:1")

context = _nplusone_context.get()
assert context.ignored == {"Test:1"}
assert list(context.counts.values()) == [1]

with zealot_context():
# new context, fresh state
context = _nplusone_context.get()
assert context.ignored == set()
assert list(context.counts.values()) == []

n_plus_one_listener.ignore("NestedTest:1")
n_plus_one_listener.notify(Post, "nested_test_field", "Post:1")

context = _nplusone_context.get()
assert context.ignored == {"NestedTest:1"}
assert list(context.counts.values()) == [1]

# back outside the nested context, we're back to the old state
context = _nplusone_context.get()
assert context.ignored == {"Test:1"}
assert list(context.counts.values()) == [1]
10 changes: 10 additions & 0 deletions tests/test_nplusones.py
Original file line number Diff line number Diff line change
Expand Up @@ -475,3 +475,13 @@ def test_works_in_web_requests(client):
assert response.status_code == 200
response = client.get(f"/user/{user_2.pk}/")
assert response.status_code == 200


def test_ignores_calls_on_different_lines():
[user_1, user_2] = UserFactory.create_batch(2)
PostFactory.create(author=user_1)
PostFactory.create(author=user_2)

# this should *not* raise an exception
_a = list(user_1.posts.all())
_b = list(user_2.posts.all())

0 comments on commit ff8a068

Please sign in to comment.