Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add shared_db_wrapper for creating long-lived db state #258

Open
wants to merge 10 commits into
base: main
Choose a base branch
from
20 changes: 20 additions & 0 deletions docs/helpers.rst
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,26 @@ transaction support. This is only required for fixtures which need
database access themselves. A test function would normally use the
:py:func:`~pytest.mark.django_db` mark to signal it needs the database.

``shared_db_wrapper``
~~~~~~~~~~~~~~~~~~~~~

This fixture can be used to create long-lived state in the database.
It's meant to be used from fixtures with scope bigger than ``function``.
It provides a context manager that will create a new database savepoint for you,
and will take care to revert it when your fixture gets cleaned up.

At the moment it does not work with ``transactional_db``,
as the fixture itself depends on transactions.
It also needs Django >= 1.8, as earlier versions close DB connections between tests.

Example usage::

@pytest.fixture(scope='module')
def some_users(request, shared_db_wrapper):
with shared_db_wrapper(request):

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This API adds some boilerplate. I'd love to write just:

@pytest.fixture(scope='module')
def some_users(shared_db_wrapper):
    User.objects.create(username='foo')

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would like that too, but there is an issue with this. When you specify a dependency to some_users, the shared_db_wrapper fixture will still be active. This is a problem, because you only want to use global state for a limited amount of data (e.g. creating the users and not modifying them).

I think that is why OP has called it a shared_db_wrapper and not shared_db.

However, maybe there's a solution around this. Maybe it's possible to somehow create a special "destructor" before entering "local" scope. However this might be complicated given how the current architecture (I have no idea about pytest internals).

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But to at least simplify it a bit: Why do we need access to the request fixture?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We need request to register a finalizer that will run when the shared fixture (and not just the wrapper) is getting torn down.

Would it be desirable to put all that boilerplate in a decorator that would be used instead of (or maybe before?) pytest.fixture? Something like:

@shared_db_fixture(scope='module')
def some_users():
    User.objects.create(username='foo')

I don't think it's possible to hook into fixture instantiation in py.test, so we can't just use a marker on those fixture and wrap them later.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have to admit that the interactions between fixtures of different scopes are totally over my head :-/

The end goal is to call transaction.atomic and transaction.set_rollback as follows -- but I don't know how to map this to pytest's hooks and to make it work within pytest-django's existing infrastructure.

with transaction.atomic():
    fixture_1()     # used in the entire test suite

    with transaction.atomic():
        fixture_2()     # used in one module

        with transaction.atomic():
            test_21()
            transaction.set_rollback(True)

        with transaction.atomic():
            test_22()
            transaction.set_rollback(True)

        with transaction.atomic():
            test_23()
            transaction.set_rollback(True)

        transaction.set_rollback(True)  # fixture 2

    with transaction.atomic():
        fixture_3()     # used in two modules

        with transaction.atomic():
            fixture_4()     # used in one module which also uses 3

            with transaction.atomic():
                test_41()
                transaction.set_rollback(True)

            with transaction.atomic():
                test_42()
                transaction.set_rollback(True)

            with transaction.atomic():
                test_43()
                transaction.set_rollback(True)

            transaction.set_rollback(True)  # fixture 4

        with transaction.atomic():
            fixture_5()     # used in one module which also uses 3

            with transaction.atomic():
                test_51()
                transaction.set_rollback(True)

            with transaction.atomic():
                test_52()
                transaction.set_rollback(True)

            with transaction.atomic():
                test_53()
                transaction.set_rollback(True)

                transaction.set_rollback(True)

            transaction.set_rollback(True)  # fixture 5

        transaction.set_rollback(True)  # fixture 3

    transaction.set_rollback(True)  # fixture 1

If multiple databases are being used, you need to call transaction.atomic(using=...) and transaction.set_rollback(True, using=...) for each of them.

Ping me if you have questions about Django's transaction management — I wrote it.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can't use with atomic(): because I have to register a callback on request :-/.

Thank you for pointing out transaction.set_rollback(True) - this will make my code cleaner (no more dummy exceptions!).

For the using=... part I might even be able to use Django's test case enter/exit atomics, or at least get some inspiration from there.

return [User.objects.create(username='no {}'.format(i))
for i in range(1000)]

``live_server``
~~~~~~~~~~~~~~~

Expand Down
57 changes: 56 additions & 1 deletion pytest_django/fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@

from __future__ import with_statement

from contextlib import contextmanager
import os
import sys
import warnings

import pytest
Expand All @@ -13,7 +15,8 @@
from .django_compat import is_django_unittest
from .lazy_django import get_django_version, skip_if_no_django

__all__ = ['_django_db_setup', 'db', 'transactional_db', 'admin_user',
__all__ = ['_django_db_setup', 'db', 'transactional_db', 'shared_db_wrapper',
'admin_user',
'django_user_model', 'django_username_field',
'client', 'admin_client', 'rf', 'settings', 'live_server',
'_live_server_helper']
Expand Down Expand Up @@ -195,6 +198,58 @@ def transactional_db(request, _django_db_setup, _django_cursor_wrapper):
return _django_db_fixture_helper(True, request, _django_cursor_wrapper)


@pytest.fixture(scope='session')
def shared_db_wrapper(_django_db_setup, _django_cursor_wrapper):
"""Wrapper for common database initialization code.

This fixture provides a context manager that let's you access the database
from a transaction spanning multiple tests.
"""
from django.db import connection, transaction

if get_django_version() < (1, 8):
raise Exception('shared_db_wrapper is only supported on Django >= 1.8.')

class DummyException(Exception):
"""Dummy for use with Atomic.__exit__."""

@contextmanager
def wrapper(request):
# We need to take the request
# to bind finalization to the place where this is used
if 'transactional_db' in request.funcargnames:
raise Exception(
'shared_db_wrapper cannot be used with `transactional_db`.')

with _django_cursor_wrapper:
if not connection.features.supports_transactions:
raise Exception(
"shared_db_wrapper cannot be used when "
"the database doesn't support transactions.")

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems to me that the next 25 lines could be written as:

        try:
            _django_cursor_wrapper.enable()
            with transaction.atomic():
                yield
                transaction.set_rollback(True)
        finally:
            _django_cursor_wrapper.restore()

This raises the question of why this fixture needs to care about _django_cursor_wrapper; it's unclear to me why this is necessary. If it isn't, the code can be further simplified to:

        with transaction.atomic():
            yield
            transaction.set_rollback(True)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ktosiek What do you think about that?

# Use atomic instead of calling .savepoint* directly.
# This way works for both top-level transactions and "subtransactions".
atomic = transaction.atomic()

def finalize():
# dummy exception makes `atomic` rollback the savepoint
with _django_cursor_wrapper:
atomic.__exit__(DummyException, DummyException(), None)

try:
_django_cursor_wrapper.enable()
atomic.__enter__()
yield
request.addfinalizer(finalize)
except:
atomic.__exit__(*sys.exc_info())
raise
finally:
_django_cursor_wrapper.restore()

return wrapper


@pytest.fixture()
def client():
"""A Django test client instance."""
Expand Down
6 changes: 3 additions & 3 deletions pytest_django/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,14 @@
from .django_compat import is_django_unittest
from .fixtures import (_django_db_setup, _live_server_helper, admin_client,
admin_user, client, db, django_user_model,
django_username_field, live_server, rf, settings,
transactional_db)
django_username_field, live_server, rf, shared_db_wrapper,
settings, transactional_db)
from .lazy_django import django_settings_is_configured, skip_if_no_django

# Silence linters for imported fixtures.
(_django_db_setup, _live_server_helper, admin_client, admin_user, client, db,
django_user_model, django_username_field, live_server, rf, settings,
transactional_db)
shared_db_wrapper, transactional_db)


SETTINGS_MODULE_ENV = 'DJANGO_SETTINGS_MODULE'
Expand Down
72 changes: 72 additions & 0 deletions tests/test_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from django.db import connection, transaction
from django.test.testcases import connections_support_transactions

from pytest_django.lazy_django import get_django_version
from pytest_django_test.app.models import Item


Expand Down Expand Up @@ -51,6 +52,77 @@ def test_noaccess_fixture(noaccess):
pass


@pytest.mark.skipif(
get_django_version() < (1, 8),
reason="shared_db_wrapper needs at least Django 1.8")
def test_shared_db_wrapper(django_testdir):
django_testdir.create_test_module('''
from .app.models import Item
import pytest
from uuid import uuid4

@pytest.fixture(scope='session')
def session_item(request, shared_db_wrapper):
with shared_db_wrapper(request):
return Item.objects.create(name='session-' + uuid4().hex)

@pytest.fixture(scope='module')
def module_item(request, shared_db_wrapper):
with shared_db_wrapper(request):
return Item.objects.create(name='module-' + uuid4().hex)

@pytest.fixture(scope='class')
def class_item(request, shared_db_wrapper):
with shared_db_wrapper(request):
return Item.objects.create(name='class-' + uuid4().hex)

@pytest.fixture
def function_item(db):
return Item.objects.create(name='function-' + uuid4().hex)

class TestItems:
def test_save_the_items(
self, session_item, module_item, class_item,
function_item, db):
global _session_item
global _module_item
global _class_item
assert session_item.pk
assert module_item.pk
assert class_item.pk
_session_item = session_item
_module_item = module_item
_class_item = class_item

def test_mixing_with_non_db_tests(self):
pass

def test_accessing_the_same_items(
self, db, session_item, module_item, class_item):
assert _session_item.name == session_item.name
Item.objects.get(pk=_session_item.pk)
assert _module_item.name == module_item.name
Item.objects.get(pk=_module_item.pk)
assert _class_item.name == class_item.name
Item.objects.get(pk=_class_item.pk)

def test_mixing_with_other_db_tests(db):
Item.objects.get(name=_module_item.name)
assert Item.objects.filter(name__startswith='function').count() == 0

class TestSharing:
def test_sharing_some_items(
self, db, session_item, module_item, class_item,
function_item):
assert _session_item.name == session_item.name
assert _module_item.name == module_item.name
assert _class_item.name != class_item.name
assert Item.objects.filter(name__startswith='function').count() == 1
''')
result = django_testdir.runpytest_subprocess('-v', '-s', '--reuse-db')
assert result.ret == 0


class TestDatabaseFixtures:
"""Tests for the db and transactional_db fixtures"""

Expand Down