diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml new file mode 100644 index 0000000..5c9b7b8 --- /dev/null +++ b/.github/workflows/test.yaml @@ -0,0 +1,24 @@ +name: Test +on: [push] + +jobs: + test: + runs-on: ubuntu-latest + strategy: + matrix: + python-version: ["3.11", "3.12"] + django-version: ["4.2", "5.0"] + + steps: + - uses: actions/checkout@v4 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + cache: 'pip' + # we don't actually install from requirements.txt because the only dependency + # (currently) is django + - run: pip install Django~=${{ matrix.django-version }} + - run: make test + + diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..9ac2dd8 --- /dev/null +++ b/.gitignore @@ -0,0 +1,163 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. +# https://pdm.fming.dev/latest/usage/project/#working-with-version-control +.pdm.toml +.pdm-python +.pdm-build/ + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. +#.idea/ + diff --git a/.tool-versions b/.tool-versions new file mode 100644 index 0000000..1c4ce99 --- /dev/null +++ b/.tool-versions @@ -0,0 +1 @@ +python 3.12.4 diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..8f3179f --- /dev/null +++ b/LICENSE @@ -0,0 +1,20 @@ +Copyright 2016 Joshua Carp +Copyright 2024 Tao Bojlén + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..ae78220 --- /dev/null +++ b/Makefile @@ -0,0 +1,5 @@ +install: + uv pip compile requirements.in -o requirements.txt && uv pip compile requirements-dev.in -o requirements-dev.txt && uv pip install -r requirements.txt && uv pip install -r requirements-dev.txt + +test: + pytest -s diff --git a/README.md b/README.md new file mode 100644 index 0000000..aef98af --- /dev/null +++ b/README.md @@ -0,0 +1,29 @@ +# queryspy + +This library catches N+1s in your Django project. + +## Features + +- Raises an error when N+1s are detected +- TODO: configurable thresholds +- TODO: allowlist +- TODO: catches unused eager loads +- Well-tested +- No dependencies + +## Acknowledgements + +This library draws very heavily on jmcarp's [nplusone](https://github.com/jmcarp/nplusone/). +It's not *exactly* a fork, but not far from it. + +## Installation + +TODO. + +## Contributing + +1. First, install [uv](https://github.com/astral-sh/uv). +2. Create a virtual env using `uv venv` and activate it with `source .venv/bin/activate`. +3. Run `make install` to install dev dependencies. +4. To run tests, run `make test`. + diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..12b2b87 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,17 @@ +[tool.ruff] +line-length = 79 + +[tool.ruff.lint] +extend-select = [ + "I", # isort + "N", # naming + "B", # bugbear + "FIX", # disallow FIXME/TODO comments + "F", # pyflakes +] + +[tool.pytest.ini_options] +DJANGO_SETTINGS_MODULE = "djangoproject.settings" +pythonpath = ["src", "tests"] +testpaths = ["tests"] +addopts = "--nomigrations" diff --git a/requirements-dev.in b/requirements-dev.in new file mode 100644 index 0000000..9165850 --- /dev/null +++ b/requirements-dev.in @@ -0,0 +1,7 @@ +# ensure that dev deps are constrained to production deps +-c requirements.txt + +pytest~=8.2.2 +pytest-django~=4.8.0 +factory-boy~=3.3.0 +ruff~=0.5.0 diff --git a/requirements-dev.txt b/requirements-dev.txt new file mode 100644 index 0000000..6a5ca3b --- /dev/null +++ b/requirements-dev.txt @@ -0,0 +1,24 @@ +# This file was autogenerated by uv via the following command: +# uv pip compile requirements-dev.in -o requirements-dev.txt +factory-boy==3.3.0 + # via -r requirements-dev.in +faker==26.0.0 + # via factory-boy +iniconfig==2.0.0 + # via pytest +packaging==24.1 + # via pytest +pluggy==1.5.0 + # via pytest +pytest==8.2.2 + # via + # -r requirements-dev.in + # pytest-django +pytest-django==4.8.0 + # via -r requirements-dev.in +python-dateutil==2.9.0.post0 + # via faker +ruff==0.5.0 + # via -r requirements-dev.in +six==1.16.0 + # via python-dateutil diff --git a/requirements.in b/requirements.in new file mode 100644 index 0000000..a0b0f80 --- /dev/null +++ b/requirements.in @@ -0,0 +1 @@ +Django~=5.0 diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..d62a8aa --- /dev/null +++ b/requirements.txt @@ -0,0 +1,8 @@ +# This file was autogenerated by uv via the following command: +# uv pip compile requirements.in -o requirements.txt +asgiref==3.8.1 + # via django +django==5.0.6 + # via -r requirements.in +sqlparse==0.5.0 + # via django diff --git a/src/queryspy/__init__.py b/src/queryspy/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/queryspy/apps.py b/src/queryspy/apps.py new file mode 100644 index 0000000..123ded2 --- /dev/null +++ b/src/queryspy/apps.py @@ -0,0 +1,13 @@ +from django.apps import AppConfig + +from .patch import patch + +# from . import ugh + + +class QuerySpyConfig(AppConfig): + name = "queryspy" + + def ready(self): + patch() + pass diff --git a/src/queryspy/errors.py b/src/queryspy/errors.py new file mode 100644 index 0000000..2a8503b --- /dev/null +++ b/src/queryspy/errors.py @@ -0,0 +1,6 @@ +class QuerySpyError(Exception): + pass + + +class NPlusOneError(QuerySpyError): + pass diff --git a/src/queryspy/listeners.py b/src/queryspy/listeners.py new file mode 100644 index 0000000..9d78ff1 --- /dev/null +++ b/src/queryspy/listeners.py @@ -0,0 +1,32 @@ +from abc import ABC, abstractmethod +from collections import defaultdict +from typing import Type + +from django.db import models + +from .errors import NPlusOneError + +ModelAndField = tuple[Type[models.Model], str] +THRESHOLD = 2 + + +class Listener(ABC): + @abstractmethod + def notify(self, *args, **kwargs): ... + + +class NPlusOneListener(Listener): + counts: dict[ModelAndField, int] + + def __init__(self): + self.counts = defaultdict(int) + + def notify(self, model: Type[models.Model], field: str): + key = (model, field) + self.counts[key] += 1 + count = self.counts[key] + if count >= THRESHOLD: + raise NPlusOneError("BAD!") + + +n_plus_one_listener = NPlusOneListener() diff --git a/src/queryspy/patch.py b/src/queryspy/patch.py new file mode 100644 index 0000000..0456625 --- /dev/null +++ b/src/queryspy/patch.py @@ -0,0 +1,159 @@ +import functools +import importlib +import inspect +from typing import Any, Callable, NotRequired, TypedDict, Unpack + +from django.db import models +from django.db.models.fields.related_descriptors import ( + ForwardManyToOneDescriptor, + ReverseOneToOneDescriptor, + create_forward_many_to_many_manager, + create_reverse_many_to_one_manager, +) + +from .listeners import ModelAndField, n_plus_one_listener + + +class QuerysetContext(TypedDict): + args: NotRequired[Any] + kwargs: NotRequired[Any] + + # This is only used for many-to-many relations. It contains the call args + # when `create_forward_many_to_many_manager` is called. + manager_call_args: NotRequired[dict[str, Any]] + + +Parser = Callable[[QuerysetContext], ModelAndField] + + +def patch_module_function(original, patched): + module = importlib.import_module(original.__module__) + setattr(module, original.__name__, patched) + + +def patch_queryset_fetch_all( + queryset: models.QuerySet, parser: Parser, context: QuerysetContext +): + fetch_all = queryset._fetch_all + + @functools.wraps(fetch_all) + def wrapper(*args, **kwargs): + if queryset._result_cache is None: + n_plus_one_listener.notify(*parser(context)) + return fetch_all(*args, **kwargs) + + return wrapper + + +def patch_queryset_function( + queryset_func: Callable[..., models.QuerySet], + parser: Parser, + **context: Unpack[QuerysetContext], +): + @functools.wraps(queryset_func) + def wrapper(*args, **kwargs): + queryset = queryset_func(*args, **kwargs) + context["args"] = context.get("args", args) + context["kwargs"] = context.get("kwargs", kwargs) + queryset._clone = patch_queryset_function( + queryset._clone, parser, **context + ) + queryset._fetch_all = patch_queryset_fetch_all( + queryset, parser, context + ) + return queryset + + return wrapper + + +def patch_forward_many_to_one_descriptor(): + """ + This also handles ForwardOneToOneDescriptor, which is + a subclass of ForwardManyToOneDescriptor. + """ + + # in ForwardManyToOneDescriptor, get_object is only called when the related + # object is not prefetched + def patch_get_object(func): + @functools.wraps(func) + def wrapper(*args, **kwargs): + descriptor = args[0] + n_plus_one_listener.notify( + descriptor.field.model, descriptor.field.name + ) + return func(*args, **kwargs) + + return wrapper + + ForwardManyToOneDescriptor.get_object = patch_get_object( + ForwardManyToOneDescriptor.get_object + ) + + +def patch_reverse_many_to_one_descriptor(): + def parser(context: QuerysetContext): + assert "args" in context + field = context["args"][0] + return (field.model, field.name) + + def patched_create_reverse_many_to_one_manager(*args, **kwargs): + manager = create_reverse_many_to_one_manager(*args, **kwargs) + manager.get_queryset = patch_queryset_function( + manager.get_queryset, parser + ) + return manager + + patch_module_function( + create_reverse_many_to_one_manager, + patched_create_reverse_many_to_one_manager, + ) + + +def patch_reverse_one_to_one_descriptor(): + def parser(context: QuerysetContext): + assert "args" in context + descriptor = context["args"][0] + field = descriptor.related.field + return (field.model, field.name) + + ReverseOneToOneDescriptor.get_queryset = patch_queryset_function( + ReverseOneToOneDescriptor.get_queryset, parser + ) + + +def patch_many_to_many_descriptor(): + def parser(context: QuerysetContext): + assert ( + "manager_call_args" in context + and "rel" in context["manager_call_args"] + ) + assert "args" in context + rel = context["manager_call_args"]["rel"] + manager = context["args"][0] + model = manager.instance.__class__ + related_model = manager.target_field.related_model + field_name = manager.prefetch_cache_name if rel.related_name else None + + return (model, field_name or f"{related_model._meta.model_name}_set") + + def patched_create_forward_many_to_many_manager(*args, **kwargs): + manager_call_args = inspect.getcallargs( + create_forward_many_to_many_manager, *args, **kwargs + ) + manager = create_forward_many_to_many_manager(*args, **kwargs) + manager.get_queryset = patch_queryset_function( + manager.get_queryset, parser, manager_call_args=manager_call_args + ) + return manager + + patch_module_function( + create_forward_many_to_many_manager, + patched_create_forward_many_to_many_manager, + ) + + +def patch(): + patch_forward_many_to_one_descriptor() + patch_reverse_many_to_one_descriptor() + patch_reverse_one_to_one_descriptor() + patch_many_to_many_descriptor() diff --git a/src/queryspy/ugh.py b/src/queryspy/ugh.py new file mode 100644 index 0000000..4d459ba --- /dev/null +++ b/src/queryspy/ugh.py @@ -0,0 +1,275 @@ +# -*- coding: utf-8 -*- + +import copy +import functools +import importlib +import inspect +import threading + +from django.db.models import Model, query +from django.db.models.fields.related_descriptors import ( + ForwardManyToOneDescriptor, + ReverseOneToOneDescriptor, + create_forward_many_to_many_manager, + create_reverse_many_to_one_manager, +) + + +def get_worker(): + return str(threading.current_thread().ident) + + +def to_key(instance): + model = type(instance) + return ":".join([model.__name__, format(instance.pk)]) + + +def patch(original, patched): + module = importlib.import_module(original.__module__) + setattr(module, original.__name__, patched) + + +def signalify_queryset(func, parser=None, **context): + @functools.wraps(func) + def wrapped(*args, **kwargs): + queryset = func(*args, **kwargs) + ctx = copy.copy(context) + ctx["args"] = context.get("args", args) + ctx["kwargs"] = context.get("kwargs", kwargs) + queryset._clone = signalify_queryset( + queryset._clone, parser=parser, **ctx + ) + queryset._fetch_all = signalify_fetch_all( + queryset, parser=parser, **ctx + ) + queryset._context = ctx + return queryset + + return wrapped + + +def signalify_fetch_all(queryset, parser=None, **context): + """Signal lazy load when `QuerySet._fetch_all` fetches rows. Note: patch + `_fetch_all` instead of `iterator` since, as of Django 1.11, the former is + used for all fetches while the latter is not. + """ + func = queryset._fetch_all + + @functools.wraps(func) + def wrapped(*args, **kwargs): + return func(*args, **kwargs) + + return wrapped + + +def get_related_name(model): + return "{0}_set".format(model._meta.model_name) + + +def parse_field(field): + return ( + ( + field.related_model # Django >= 1.8 + if hasattr(field, "related_model") + else field.related_field.model # Django <= 1.8 + ), + ( + field.remote_field.name # Django >= 1.8 + if hasattr(field, "remote_field") + else field.rel.related_name # Django <= 1.8 + ) + or get_related_name(field.related_model), + ) + + +def parse_reverse_field(field): + return field.model, field.name + + +def parse_related(context): + if "rel" in context: # pragma: no cover + rel = context["rel"] + return parse_related_parts( + rel.model, rel.related_name, rel.related_model + ) + else: # pragma: no cover + field = context["rel_field"] + model = field.related_field.model + related_name = field.rel.related_name + related_model = context["rel_model"] + return parse_related_parts(model, related_name, related_model) + + +def parse_related_parts(model, related_name, related_model): + return ( + model, + related_name or get_related_name(related_model), + ) + + +def parse_reverse_one_to_one_queryset(args, kwargs, context): + descriptor = context["args"][0] + field = descriptor.related.field + model, name = parse_field(field) + instance = context["kwargs"]["instance"] + return model, to_key(instance), name + + +def parse_forward_many_to_one_queryset(args, kwargs, context): + descriptor = context["args"][0] + instance = context["kwargs"]["instance"] + return descriptor.field.model, to_key(instance), descriptor.field.name + + +def parse_many_related_queryset(args, kwargs, context): + rel = context["rel"] + manager = context["args"][0] + model = manager.instance.__class__ + related_model = ( + manager.target_field.related_model # Django >= 1.8 + if hasattr(manager.target_field, "related_model") + else manager.target_field.related_field.model # Django <= 1.8 + ) + field = manager.prefetch_cache_name if rel.related_name else None + return ( + model, + to_key(manager.instance), + field or get_related_name(related_model), + ) + + +def parse_foreign_related_queryset(args, kwargs, context): + model, name = parse_related(context) + descriptor = context["args"][0] + return model, to_key(descriptor.instance), name + + +def parse_get(args, kwargs, context, ret): + return [to_key(ret)] if isinstance(ret, Model) else [] + + +# Ignore records loaded during `get` + +ReverseOneToOneDescriptor.get_queryset = signalify_queryset( + ReverseOneToOneDescriptor.get_queryset, + parser=parse_reverse_one_to_one_queryset, +) +ForwardManyToOneDescriptor.get_queryset = signalify_queryset( + ForwardManyToOneDescriptor.get_queryset, + parser=parse_forward_many_to_one_queryset, +) + + +def _create_forward_many_to_many_manager(*args, **kwargs): + context = inspect.getcallargs( + create_forward_many_to_many_manager, *args, **kwargs + ) + manager = create_forward_many_to_many_manager(*args, **kwargs) + manager.get_queryset = signalify_queryset( + manager.get_queryset, parser=parse_many_related_queryset, **context + ) + return manager + + +patch( + create_forward_many_to_many_manager, _create_forward_many_to_many_manager +) + + +def _create_reverse_many_to_one_manager(*args, **kwargs): + context = inspect.getcallargs( + create_reverse_many_to_one_manager, *args, **kwargs + ) + manager = create_reverse_many_to_one_manager(*args, **kwargs) + + manager.get_queryset = signalify_queryset( + manager.get_queryset, parser=parse_foreign_related_queryset, **context + ) + return manager + + +patch(create_reverse_many_to_one_manager, _create_reverse_many_to_one_manager) + + +def parse_forward_many_to_one_get(args, kwargs, context): + descriptor, instance, _ = args + if instance is None: + return None + field, model = parse_reverse_field(descriptor.field) + return field, model, [to_key(instance)] + + +def parse_reverse_one_to_one_get(args, kwargs, context): + descriptor, instance = args[:2] + if instance is None: + return None + model, field = parse_field(descriptor.related.field) + return model, field, [to_key(instance)] + + +def parse_fetch_all(args, kwargs, context): + self = args[0] + if hasattr(self, "_context"): + manager = self._context["args"][0] + instance = manager.instance + # Handle iteration over many-to-many relationship + if manager.__class__.__name__ == "ManyRelatedManager": + return ( + instance.__class__, + parse_manager_field(manager, self._context["rel"]), + [to_key(instance)], + ) + # Handle iteration over one-to-many relationship + else: + model, field = parse_related(self._context) + return model, field, [to_key(instance)] + + +def parse_manager_field(manager, rel): + if manager.reverse: + return rel.related_name or get_related_name(rel.related_model) + return rel.field.name or get_related_name(rel.model) + + +def parse_load(args, kwargs, context, ret): + return [to_key(row) for row in ret if isinstance(row, Model)] + + +def is_single(low, high): + return high is not None and high - low == 1 + + +original_related_populator_init = query.RelatedPopulator.__init__ + + +def related_populator_init(self, *args, **kwargs): + original_related_populator_init(self, *args, **kwargs) + self.__nplusone__ = { + "args": args, + "kwargs": kwargs, + } + + +query.RelatedPopulator.__init__ = related_populator_init + + +def parse_eager_select(args, kwargs, context): + populator = args[0] + instance = args[2] + meta = populator.__nplusone__ + klass_info, select, _ = meta["args"] + field = klass_info["field"] + model, name = ( + parse_field(field) + if instance._meta.model != field.model + else parse_reverse_field(field) + ) + return model, name, [to_key(instance)], id(select) + + +def parse_eager_join(args, kwargs, context): + instances, descriptor, fetcher, level = args + model = instances[0].__class__ + field, _ = fetcher.get_current_to_attr(level) + keys = [to_key(instance) for instance in instances] + return model, field, keys, id(instances) diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/tests/__init__.py @@ -0,0 +1 @@ + diff --git a/tests/djangoproject/__init__.py b/tests/djangoproject/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/djangoproject/manage.py b/tests/djangoproject/manage.py new file mode 100755 index 0000000..9313c7d --- /dev/null +++ b/tests/djangoproject/manage.py @@ -0,0 +1,23 @@ +#!/usr/bin/env python +"""Django's command-line utility for administrative tasks.""" + +import os +import sys + + +def main(): + """Run administrative tasks.""" + os.environ.setdefault("DJANGO_SETTINGS_MODULE", "djangoproject.settings") + try: + from django.core.management import execute_from_command_line + except ImportError as exc: + raise ImportError( + "Couldn't import Django. Are you sure it's installed and " + "available on your PYTHONPATH environment variable? Did you " + "forget to activate a virtual environment?" + ) from exc + execute_from_command_line(sys.argv) + + +if __name__ == "__main__": + main() diff --git a/tests/djangoproject/settings.py b/tests/djangoproject/settings.py new file mode 100644 index 0000000..e126ff6 --- /dev/null +++ b/tests/djangoproject/settings.py @@ -0,0 +1,28 @@ +SECRET_KEY = 1 +DEBUG = True +USE_TZ = True +TIME_ZONE = "UTC" + +INSTALLED_APPS = [ + "django.contrib.admin", + "django.contrib.auth", + "django.contrib.contenttypes", + "django.contrib.sessions", + "django.contrib.messages", + "django.contrib.staticfiles", + "djangoproject.social", + "queryspy", +] + +MIDDLEWARE = [] + +ROOT_URLCONF = "djangoproject.urls" + +DATABASES = { + "default": { + "ENGINE": "django.db.backends.sqlite3", + "NAME": ":memory:", + } +} + +DEFAULT_AUTO_FIELD = "django.db.models.BigAutoField" diff --git a/tests/djangoproject/social/__init__.py b/tests/djangoproject/social/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/djangoproject/social/apps.py b/tests/djangoproject/social/apps.py new file mode 100644 index 0000000..11ae04a --- /dev/null +++ b/tests/djangoproject/social/apps.py @@ -0,0 +1,6 @@ +from django.apps import AppConfig + + +class SocialConfig(AppConfig): + default_auto_field = "django.db.models.BigAutoField" + name = "djangoproject.social" diff --git a/tests/djangoproject/social/models.py b/tests/djangoproject/social/models.py new file mode 100644 index 0000000..825cda8 --- /dev/null +++ b/tests/djangoproject/social/models.py @@ -0,0 +1,27 @@ +from django.db import models + + +class User(models.Model): + username = models.TextField() + # user.followers and user.following are both ManyToManyDescriptor + following = models.ManyToManyField("User", related_name="followers") + + # note that there's no related_name set here, because we want to + # test that case too. + blocked = models.ManyToManyField("user") + + +class Profile(models.Model): + # profile.user is ForwardOneToOne + # user.profile is ReverseOneToOne + user = models.OneToOneField(User, on_delete=models.CASCADE) + display_name = models.TextField() + + +class Post(models.Model): + # post.author is ForwardManyToOne + # user.posts is ReverseManyToOne + author = models.ForeignKey( + User, on_delete=models.CASCADE, related_name="posts" + ) + text = models.TextField() diff --git a/tests/factories.py b/tests/factories.py new file mode 100644 index 0000000..919ecdd --- /dev/null +++ b/tests/factories.py @@ -0,0 +1,23 @@ +import factory +from djangoproject.social.models import Post, Profile, User + + +class UserFactory(factory.django.DjangoModelFactory): + username = factory.Faker("user_name") + + class Meta: + model = User + + +class ProfileFactory(factory.django.DjangoModelFactory): + display_name = factory.Faker("name") + + class Meta: + model = Profile + + +class PostFactory(factory.django.DjangoModelFactory): + text = factory.Faker("sentence") + + class Meta: + model = Post diff --git a/tests/test_nplusones.py b/tests/test_nplusones.py new file mode 100644 index 0000000..56237c5 --- /dev/null +++ b/tests/test_nplusones.py @@ -0,0 +1,131 @@ +import pytest +from djangoproject.social.models import Post, Profile, User +from queryspy.errors import NPlusOneError + +from .factories import PostFactory, ProfileFactory, UserFactory + +pytestmark = pytest.mark.django_db + + +def test_detects_nplusone_in_forward_many_to_one(): + [user_1, user_2] = UserFactory.create_batch(2) + PostFactory.create(author=user_1) + PostFactory.create(author=user_2) + with pytest.raises(NPlusOneError): + for post in Post.objects.all(): + _ = post.author.username + + +def test_detects_nplusone_in_reverse_many_to_one(): + [user_1, user_2] = UserFactory.create_batch(2) + PostFactory.create(author=user_1) + PostFactory.create(author=user_2) + with pytest.raises(NPlusOneError): + for user in User.objects.all(): + _ = list(user.posts.all()) + + +def test_detects_nplusone_in_forward_one_to_one(): + [user_1, user_2] = UserFactory.create_batch(2) + ProfileFactory.create(user=user_1) + ProfileFactory.create(user=user_2) + with pytest.raises(NPlusOneError): + for profile in Profile.objects.all(): + _ = profile.user.username + + +def test_detects_nplusone_in_reverse_one_to_one(): + [user_1, user_2] = UserFactory.create_batch(2) + ProfileFactory.create(user=user_1) + ProfileFactory.create(user=user_2) + with pytest.raises(NPlusOneError): + for user in User.objects.all(): + _ = user.profile.display_name + + +def test_detects_nplusone_in_forward_many_to_many(): + [user_1, user_2] = UserFactory.create_batch(2) + user_1.following.add(user_2) + user_2.following.add(user_1) + with pytest.raises(NPlusOneError): + for user in User.objects.all(): + _ = list(user.following.all()) + + +def test_detects_nplusone_in_reverse_many_to_many(): + [user_1, user_2] = UserFactory.create_batch(2) + user_1.following.add(user_2) + user_2.following.add(user_1) + with pytest.raises(NPlusOneError): + for user in User.objects.all(): + _ = list(user.followers.all()) + + +def test_detects_nplusone_in_reverse_many_to_many_with_no_related_name(): + [user_1, user_2] = UserFactory.create_batch(2) + user_1.blocked.add(user_2) + user_2.blocked.add(user_1) + with pytest.raises(NPlusOneError): + for user in User.objects.all(): + _ = list(user.user_set.all()) + + +# def test_detects_nplusone_due_to_deferred_fields(): +# pass + + +def test_does_not_raise_when_forward_many_to_one_prefetched(): + [user_1, user_2] = UserFactory.create_batch(2) + PostFactory.create(author=user_1) + PostFactory.create(author=user_2) + for post in Post.objects.select_related("author").all(): + _ = post.author.username + + +def test_does_not_raise_when_reverse_many_to_one_prefetched(): + [user_1, user_2] = UserFactory.create_batch(2) + PostFactory.create(author=user_1) + PostFactory.create(author=user_2) + for user in User.objects.prefetch_related("posts").all(): + _ = list(user.posts.all()) + + +def test_does_not_raise_when_forward_one_to_one_prefetched(): + [user_1, user_2] = UserFactory.create_batch(2) + ProfileFactory.create(user=user_1) + ProfileFactory.create(user=user_2) + + for profile in Profile.objects.select_related("user").all(): + _ = profile.user.username + + +def test_does_not_raise_when_reverse_one_to_one_prefetched(): + [user_1, user_2] = UserFactory.create_batch(2) + ProfileFactory.create(user=user_1) + ProfileFactory.create(user=user_2) + for user in User.objects.select_related("profile").all(): + _ = user.profile.display_name + + +def test_does_not_raise_when_forward_many_to_many_prefetched(): + [user_1, user_2] = UserFactory.create_batch(2) + user_1.following.add(user_2) + user_2.following.add(user_1) + for user in User.objects.prefetch_related("following").all(): + _ = list(user.following.all()) + + +def test_does_not_raise_when_reverse_many_to_many_prefetched(): + [user_1, user_2] = UserFactory.create_batch(2) + user_1.following.add(user_2) + user_2.following.add(user_1) + for user in User.objects.prefetch_related("followers").all(): + _ = list(user.followers.all()) + + +def test_does_not_raise_when_reverse_many_to_many_with_no_related_name_prefetched(): + [user_1, user_2] = UserFactory.create_batch(2) + user_1.blocked.add(user_2) + user_2.blocked.add(user_1) + for user in User.objects.prefetch_related("user_set").all(): + _ = list(user.user_set.all())