Skip to content

Commit

Permalink
feat: ensure errors have useful error messages
Browse files Browse the repository at this point in the history
  • Loading branch information
taobojlen committed Jun 30, 2024
1 parent 5a4f8a1 commit 1b10952
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 17 deletions.
2 changes: 1 addition & 1 deletion src/queryspy/listeners.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def notify(self, model: Type[models.Model], field: str):
self.counts[key] += 1
count = self.counts[key]
if count >= threshold:
raise NPlusOneError("BAD!")
raise NPlusOneError(f"N+1 detected on {model.__name__}.{field}")

def reset(self):
self.counts = defaultdict(int)
Expand Down
31 changes: 24 additions & 7 deletions src/queryspy/patch.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import functools
import importlib
import inspect
from typing import Any, Callable, NotRequired, TypedDict, Unpack
from typing import Any, Callable, NotRequired, Type, TypedDict, Unpack

from django.db import models
from django.db.models.fields.related_descriptors import (
Expand Down Expand Up @@ -93,16 +93,33 @@ def wrapper(*args, **kwargs):
)


def parse_related_parts(
model: Type[models.Model],
related_name: str | None,
related_model: Type[models.Model],
) -> ModelAndField:
field_name = related_name or f"{related_model._meta.model_name}_set"
return (model, field_name)


def patch_reverse_many_to_one_descriptor():
def parser(context: QuerysetContext):
assert "args" in context
field = context["args"][0]
return (field.model, field.name)
assert (
"manager_call_args" in context
and "rel" in context["manager_call_args"]
)
rel = context["manager_call_args"]["rel"]
return parse_related_parts(
rel.model, rel.related_name, rel.related_model
)

def patched_create_reverse_many_to_one_manager(*args, **kwargs):
manager_call_args = inspect.getcallargs(
create_reverse_many_to_one_manager, *args, **kwargs
)
manager = create_reverse_many_to_one_manager(*args, **kwargs)
manager.get_queryset = patch_queryset_function(
manager.get_queryset, parser
manager.get_queryset, parser, manager_call_args=manager_call_args
)
return manager

Expand All @@ -117,7 +134,7 @@ def parser(context: QuerysetContext):
assert "args" in context
descriptor = context["args"][0]
field = descriptor.related.field
return (field.model, field.name)
return (field.related_model, field.remote_field.name)

ReverseOneToOneDescriptor.get_queryset = patch_queryset_function(
ReverseOneToOneDescriptor.get_queryset, parser
Expand All @@ -137,7 +154,7 @@ def parser(context: QuerysetContext):
related_model = manager.target_field.related_model
field_name = manager.prefetch_cache_name if rel.related_name else None

return (model, field_name or f"{related_model._meta.model_name}_set")
return parse_related_parts(model, field_name, related_model)

def patched_create_forward_many_to_many_manager(*args, **kwargs):
manager_call_args = inspect.getcallargs(
Expand Down
35 changes: 26 additions & 9 deletions tests/test_nplusones.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import re

import pytest
from djangoproject.social.models import Post, Profile, User
from queryspy import NPlusOneError, queryspy_context
Expand All @@ -12,7 +14,9 @@ def test_detects_nplusone_in_forward_many_to_one():
[user_1, user_2] = UserFactory.create_batch(2)
PostFactory.create(author=user_1)
PostFactory.create(author=user_2)
with pytest.raises(NPlusOneError):
with pytest.raises(
NPlusOneError, match=re.escape("N+1 detected on Post.author")
):
for post in Post.objects.all():
_ = post.author.username

Expand All @@ -25,7 +29,9 @@ def test_detects_nplusone_in_reverse_many_to_one():
[user_1, user_2] = UserFactory.create_batch(2)
PostFactory.create(author=user_1)
PostFactory.create(author=user_2)
with pytest.raises(NPlusOneError):
with pytest.raises(
NPlusOneError, match=re.escape("N+1 detected on User.posts")
):
for user in User.objects.all():
_ = list(user.posts.all())

Expand All @@ -38,7 +44,9 @@ def test_detects_nplusone_in_forward_one_to_one():
[user_1, user_2] = UserFactory.create_batch(2)
ProfileFactory.create(user=user_1)
ProfileFactory.create(user=user_2)
with pytest.raises(NPlusOneError):
with pytest.raises(
NPlusOneError, match=re.escape("N+1 detected on Profile.user")
):
for profile in Profile.objects.all():
_ = profile.user.username

Expand All @@ -51,7 +59,9 @@ def test_detects_nplusone_in_reverse_one_to_one():
[user_1, user_2] = UserFactory.create_batch(2)
ProfileFactory.create(user=user_1)
ProfileFactory.create(user=user_2)
with pytest.raises(NPlusOneError):
with pytest.raises(
NPlusOneError, match=re.escape("N+1 detected on User.profile")
):
for user in User.objects.all():
_ = user.profile.display_name

Expand All @@ -64,7 +74,9 @@ def test_detects_nplusone_in_forward_many_to_many():
[user_1, user_2] = UserFactory.create_batch(2)
user_1.following.add(user_2)
user_2.following.add(user_1)
with pytest.raises(NPlusOneError):
with pytest.raises(
NPlusOneError, match=re.escape("N+1 detected on User.following")
):
for user in User.objects.all():
_ = list(user.following.all())

Expand All @@ -77,7 +89,9 @@ def test_detects_nplusone_in_reverse_many_to_many():
[user_1, user_2] = UserFactory.create_batch(2)
user_1.following.add(user_2)
user_2.following.add(user_1)
with pytest.raises(NPlusOneError):
with pytest.raises(
NPlusOneError, match=re.escape("N+1 detected on User.followers")
):
for user in User.objects.all():
_ = list(user.followers.all())

Expand All @@ -90,7 +104,9 @@ def test_detects_nplusone_in_reverse_many_to_many_with_no_related_name():
[user_1, user_2] = UserFactory.create_batch(2)
user_1.blocked.add(user_2)
user_2.blocked.add(user_1)
with pytest.raises(NPlusOneError):
with pytest.raises(
NPlusOneError, match=re.escape("N+1 detected on User.user_set")
):
for user in User.objects.all():
_ = list(user.user_set.all())

Expand All @@ -103,7 +119,9 @@ def test_detects_nplusone_due_to_deferred_fields():
[user_1, user_2] = UserFactory.create_batch(2)
PostFactory.create(author=user_1)
PostFactory.create(author=user_2)
with pytest.raises(NPlusOneError):
with pytest.raises(
NPlusOneError, match=re.escape("N+1 detected on User.username")
):
for post in (
Post.objects.all().select_related("author").only("author__id")
):
Expand Down Expand Up @@ -145,7 +163,6 @@ def test_works_in_web_requests(client):
ProfileFactory.create(user=user_2)
with pytest.raises(NPlusOneError):
response = client.get("/users/")
assert response.status_code == 500

# but multiple requests work fine
response = client.get(f"/user/{user_1.pk}/")
Expand Down

0 comments on commit 1b10952

Please sign in to comment.