diff --git a/django_db_views/autodetector.py b/django_db_views/autodetector.py index 19c02a2..c3247e6 100644 --- a/django_db_views/autodetector.py +++ b/django_db_views/autodetector.py @@ -168,7 +168,8 @@ def get_previous_view_models_state(self) -> dict: view_models[key] = model_state return view_models - def get_current_view_models(self): + @staticmethod + def get_current_view_models(): view_models = {} for app_label, models in apps.all_models.items(): for model_name, model_class in models.items(): @@ -229,7 +230,8 @@ def generate_views_operations(self, graph: MigrationGraph) -> None: dependencies=dependencies, ) - def get_forward_migration_class(self, model) -> Type[ForwardViewMigrationBase]: + @staticmethod + def get_forward_migration_class(model) -> Type[ForwardViewMigrationBase]: if issubclass(model, DBMaterializedView): return ForwardMaterializedViewMigration if issubclass(model, DBView): @@ -237,7 +239,8 @@ def get_forward_migration_class(self, model) -> Type[ForwardViewMigrationBase]: else: raise NotImplementedError - def get_backward_migration_class(self, model) -> Type[BackwardViewMigrationBase]: + @staticmethod + def get_backward_migration_class(model) -> Type[BackwardViewMigrationBase]: if issubclass(model, DBMaterializedView): return BackwardMaterializedViewMigration if issubclass(model, DBView): @@ -253,7 +256,8 @@ def get_drop_migration_class(self, model) -> Type[DropViewMigration]: else: raise NotImplementedError - def get_view_definition_from_model(self, view_model: DBView) -> dict: + @classmethod + def get_view_definition_from_model(cls, view_model: DBView) -> dict: view_definitions = {} if callable(view_model.view_definition): raw_view_definition = view_model.view_definition() @@ -262,12 +266,12 @@ def get_view_definition_from_model(self, view_model: DBView) -> dict: if isinstance(raw_view_definition, dict): for engine, definition in raw_view_definition.items(): - view_definitions[engine] = self.get_cleaned_view_definition_value( + view_definitions[engine] = cls.get_cleaned_view_definition_value( definition ) else: engine = settings.DATABASES["default"]["ENGINE"] - view_definitions[engine] = self.get_cleaned_view_definition_value( + view_definitions[engine] = cls.get_cleaned_view_definition_value( raw_view_definition ) return view_definitions @@ -337,7 +341,8 @@ def _get_view_identifiers_from_operation(self, operation) -> tuple[str, str]: engine = settings.DATABASES["default"]["ENGINE"] return table_name, engine - def get_cleaned_view_definition_value(self, view_definition: str) -> str: + @staticmethod + def get_cleaned_view_definition_value(view_definition: str) -> str: assert isinstance( view_definition, str ), "View definition must be callable and return string or be itself a string." diff --git a/django_db_views/db_view.py b/django_db_views/db_view.py index 07c7b38..7d6cc77 100644 --- a/django_db_views/db_view.py +++ b/django_db_views/db_view.py @@ -12,7 +12,8 @@ def __new__(cls, *args, **kwargs): assert ( new_class._meta.managed is False ), "For DB View managed must be set to false" - DBViewsRegistry[new_class._meta.db_table] = new_class + if not new_class._meta.abstract: + DBViewsRegistry[new_class._meta.db_table] = new_class return new_class diff --git a/django_db_views/fixtures.py b/django_db_views/fixtures.py new file mode 100644 index 0000000..80a58fa --- /dev/null +++ b/django_db_views/fixtures.py @@ -0,0 +1,58 @@ +from django.apps import apps +from django.db import connection + +from django_db_views.autodetector import ViewMigrationAutoDetector + +try: + import pytest +except ImportError: + raise Exception("fixtures are available only for pytest.") + + +@pytest.fixture(scope="session") +def django_db_views_setup( + request, + django_db_blocker, + django_db_use_migrations: bool, + django_db_keepdb: bool, +) -> None: + def no_migrations_tear_up() -> None: + view_models = ViewMigrationAutoDetector.get_current_view_models() + with django_db_blocker.unblock(), connection.schema_editor() as schema_editor: + engine = schema_editor.connection.settings_dict["ENGINE"] + for view_model in view_models: + view_definition = ( + ViewMigrationAutoDetector.get_view_definition_from_model( + view_model + )[engine] + ) + forward_migration = ( + ViewMigrationAutoDetector.get_forward_migration_class(view_model)( + view_definition.strip(";"), + view_model._meta.db_table, + engine=engine, + ) + ) + # run migration + forward_migration(apps, schema_editor) + + def no_migrations_teardown() -> None: + view_models = ViewMigrationAutoDetector.get_current_view_models() + with django_db_blocker.unblock(), connection.schema_editor() as schema_editor: + engine = schema_editor.connection.settings_dict["ENGINE"] + for view_model in view_models: + backward_migration = ( + ViewMigrationAutoDetector.get_backward_migration_class(view_model)( + "", + view_model._meta.db_table, + engine=engine, + ) + ) + # run migration + backward_migration(apps, schema_editor) + + if not django_db_use_migrations: + no_migrations_tear_up() + + if not django_db_keepdb: + request.addfinalizer(no_migrations_teardown) diff --git a/pyproject.toml b/pyproject.toml index b10c47c..b34be0b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -76,6 +76,9 @@ include = [ [tool.pytest.ini_options] addopts = "--migrations --create-db" +env = [ + "PYTHONDONTWRITEBYTECODE = 1" +] [tool.tox] legacy_tox_ini = """