Skip to content

Commit

Permalink
Add pytest fixture django_db_views_setup
Browse files Browse the repository at this point in the history
  • Loading branch information
BezBartek committed Nov 27, 2024
1 parent 21bbd7e commit 165ba65
Show file tree
Hide file tree
Showing 4 changed files with 75 additions and 8 deletions.
19 changes: 12 additions & 7 deletions django_db_views/autodetector.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,8 @@ def get_previous_view_models_state(self) -> dict:
view_models[key] = model_state
return view_models

def get_current_view_models(self):
@staticmethod
def get_current_view_models():
view_models = {}
for app_label, models in apps.all_models.items():
for model_name, model_class in models.items():
Expand Down Expand Up @@ -229,15 +230,17 @@ def generate_views_operations(self, graph: MigrationGraph) -> None:
dependencies=dependencies,
)

def get_forward_migration_class(self, model) -> Type[ForwardViewMigrationBase]:
@staticmethod
def get_forward_migration_class(model) -> Type[ForwardViewMigrationBase]:
if issubclass(model, DBMaterializedView):
return ForwardMaterializedViewMigration
if issubclass(model, DBView):
return ForwardViewMigration
else:
raise NotImplementedError

def get_backward_migration_class(self, model) -> Type[BackwardViewMigrationBase]:
@staticmethod
def get_backward_migration_class(model) -> Type[BackwardViewMigrationBase]:
if issubclass(model, DBMaterializedView):
return BackwardMaterializedViewMigration
if issubclass(model, DBView):
Expand All @@ -253,7 +256,8 @@ def get_drop_migration_class(self, model) -> Type[DropViewMigration]:
else:
raise NotImplementedError

def get_view_definition_from_model(self, view_model: DBView) -> dict:
@classmethod
def get_view_definition_from_model(cls, view_model: DBView) -> dict:
view_definitions = {}
if callable(view_model.view_definition):
raw_view_definition = view_model.view_definition()
Expand All @@ -262,12 +266,12 @@ def get_view_definition_from_model(self, view_model: DBView) -> dict:

if isinstance(raw_view_definition, dict):
for engine, definition in raw_view_definition.items():
view_definitions[engine] = self.get_cleaned_view_definition_value(
view_definitions[engine] = cls.get_cleaned_view_definition_value(
definition
)
else:
engine = settings.DATABASES["default"]["ENGINE"]
view_definitions[engine] = self.get_cleaned_view_definition_value(
view_definitions[engine] = cls.get_cleaned_view_definition_value(
raw_view_definition
)
return view_definitions
Expand Down Expand Up @@ -337,7 +341,8 @@ def _get_view_identifiers_from_operation(self, operation) -> tuple[str, str]:
engine = settings.DATABASES["default"]["ENGINE"]
return table_name, engine

def get_cleaned_view_definition_value(self, view_definition: str) -> str:
@staticmethod
def get_cleaned_view_definition_value(view_definition: str) -> str:
assert isinstance(
view_definition, str
), "View definition must be callable and return string or be itself a string."
Expand Down
3 changes: 2 additions & 1 deletion django_db_views/db_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@ def __new__(cls, *args, **kwargs):
assert (
new_class._meta.managed is False
), "For DB View managed must be set to false"
DBViewsRegistry[new_class._meta.db_table] = new_class
if not new_class._meta.abstract:
DBViewsRegistry[new_class._meta.db_table] = new_class
return new_class


Expand Down
58 changes: 58 additions & 0 deletions django_db_views/fixtures.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
from django.apps import apps
from django.db import connection

from django_db_views.autodetector import ViewMigrationAutoDetector

try:
import pytest
except ImportError:
raise Exception("fixtures are available only for pytest.")


@pytest.fixture(scope="session")
def django_db_views_setup(
request,
django_db_blocker,
django_db_use_migrations: bool,
django_db_keepdb: bool,
) -> None:
def no_migrations_tear_up() -> None:
view_models = ViewMigrationAutoDetector.get_current_view_models()
with django_db_blocker.unblock(), connection.schema_editor() as schema_editor:
engine = schema_editor.connection.settings_dict["ENGINE"]
for view_model in view_models:
view_definition = (
ViewMigrationAutoDetector.get_view_definition_from_model(
view_model
)[engine]
)
forward_migration = (
ViewMigrationAutoDetector.get_forward_migration_class(view_model)(
view_definition.strip(";"),
view_model._meta.db_table,
engine=engine,
)
)
# run migration
forward_migration(apps, schema_editor)

def no_migrations_teardown() -> None:
view_models = ViewMigrationAutoDetector.get_current_view_models()
with django_db_blocker.unblock(), connection.schema_editor() as schema_editor:
engine = schema_editor.connection.settings_dict["ENGINE"]
for view_model in view_models:
backward_migration = (
ViewMigrationAutoDetector.get_backward_migration_class(view_model)(
"",
view_model._meta.db_table,
engine=engine,
)
)
# run migration
backward_migration(apps, schema_editor)

if not django_db_use_migrations:
no_migrations_tear_up()

if not django_db_keepdb:
request.addfinalizer(no_migrations_teardown)
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,9 @@ include = [

[tool.pytest.ini_options]
addopts = "--migrations --create-db"
env = [
"PYTHONDONTWRITEBYTECODE = 1"
]

[tool.tox]
legacy_tox_ini = """
Expand Down

0 comments on commit 165ba65

Please sign in to comment.