From b3cb572c04b170127ea8f69de9bae54a331b971b Mon Sep 17 00:00:00 2001 From: Keita Ichihashi Date: Sun, 10 Mar 2024 19:45:23 +0900 Subject: [PATCH 1/4] Add support for composite unique constraints --- .../apps/migrations/auto/diffable_table.py | 1 + .../apps/migrations/auto/migration_manager.py | 73 ++++++++++++++----- piccolo/apps/migrations/auto/operations.py | 1 + piccolo/apps/migrations/auto/schema_differ.py | 30 +++++++- piccolo/columns/constraints.py | 41 +++++++++++ piccolo/query/methods/alter.py | 27 +++++++ piccolo/table.py | 8 +- .../migrations/auto/test_migration_manager.py | 1 + .../migrations/auto/test_schema_differ.py | 6 +- .../migrations/auto/test_schema_snapshot.py | 6 +- 10 files changed, 169 insertions(+), 25 deletions(-) create mode 100644 piccolo/columns/constraints.py diff --git a/piccolo/apps/migrations/auto/diffable_table.py b/piccolo/apps/migrations/auto/diffable_table.py index aa609f041..73df73e0c 100644 --- a/piccolo/apps/migrations/auto/diffable_table.py +++ b/piccolo/apps/migrations/auto/diffable_table.py @@ -146,6 +146,7 @@ def __sub__(self, value: DiffableTable) -> TableDelta: db_column_name=i.column._meta.db_column_name, tablename=value.tablename, schema=self.schema, + column_class=i.column.__class__, ) for i in sorted( {ColumnComparison(column=column) for column in value.columns} diff --git a/piccolo/apps/migrations/auto/migration_manager.py b/piccolo/apps/migrations/auto/migration_manager.py index 772ec3edc..fb9389358 100644 --- a/piccolo/apps/migrations/auto/migration_manager.py +++ b/piccolo/apps/migrations/auto/migration_manager.py @@ -16,9 +16,11 @@ from piccolo.apps.migrations.auto.serialisation import deserialise_params from piccolo.columns import Column, column_types from piccolo.columns.column_types import ForeignKey, Serial +from piccolo.columns.constraints import UniqueConstraint from piccolo.engine import engine_finder from piccolo.query import Query from piccolo.query.base import DDL +from piccolo.query.methods.alter import AddUniqueConstraint, DropConstraint from piccolo.schema import SchemaDDLBase from piccolo.table import Table, create_table_class, sort_table_classes from piccolo.utils.warnings import colored_warning @@ -159,6 +161,12 @@ class MigrationManager: alter_columns: AlterColumnCollection = field( default_factory=AlterColumnCollection ) + add_unique_constraints: t.List[AddUniqueConstraint] = field( + default_factory=list + ) + drop_unique_constraints: t.List[DropConstraint] = field( + default_factory=list + ) raw: t.List[t.Union[t.Callable, AsyncFunction]] = field( default_factory=list ) @@ -258,25 +266,38 @@ def add_column( if column_class is None: raise ValueError("Unrecognised column type") - cleaned_params = deserialise_params(params=params) - column = column_class(**cleaned_params) + if column_class is UniqueConstraint: + column = column_class(**params) + else: + cleaned_params = deserialise_params(params=params) + column = column_class(**cleaned_params) + column._meta.name = column_name column._meta.db_column_name = db_column_name - self.add_columns.append( - AddColumnClass( - column=column, - tablename=tablename, - table_class_name=table_class_name, - schema=schema, + if isinstance(column_class, UniqueConstraint): + self.add_unique_constraints.append( + AddUniqueConstraint( + constraint_name=column_name, + columns=params.get("unique_columns"), # type: ignore + ) + ) + else: + self.add_columns.append( + AddColumnClass( + column=column, + tablename=tablename, + table_class_name=table_class_name, + schema=schema, + ) ) - ) def drop_column( self, table_class_name: str, tablename: str, column_name: str, + column_class: t.Type[Column], db_column_name: t.Optional[str] = None, schema: t.Optional[str] = None, ): @@ -284,6 +305,7 @@ def drop_column( DropColumn( table_class_name=table_class_name, column_name=column_name, + column_class=column_class, db_column_name=db_column_name or column_name, tablename=tablename, schema=schema, @@ -664,9 +686,16 @@ async def _run_drop_columns(self, backwards: bool = False): ) for column in columns: - await self._run_query( - _Table.alter().drop_column(column=column.column_name) - ) + if column.column_class == UniqueConstraint: + await _Table.alter().drop_constraint( + constraint_name=column.column_name + ) + else: + await self._run_query( + _Table.alter().drop_column( + column=column.column_name + ) + ) async def _run_rename_tables(self, backwards: bool = False): for rename_table in self.rename_tables: @@ -862,15 +891,23 @@ async def _run_add_columns(self, backwards: bool = False): add_column.column._meta.name ) - await self._run_query( - _Table.alter().add_column( - name=column._meta.name, column=column + if isinstance(column, UniqueConstraint): + await self._run_query( + _Table.alter().add_unique_constraint( + constraint_name=column._meta.db_column_name, + columns=column.unique_columns, + ) ) - ) - if add_column.column._meta.index: + else: await self._run_query( - _Table.create_index([add_column.column]) + _Table.alter().add_column( + name=column._meta.name, column=column + ) ) + if add_column.column._meta.index: + await self._run_query( + _Table.create_index([add_column.column]) + ) async def _run_change_table_schema(self, backwards: bool = False): from piccolo.schema import SchemaManager diff --git a/piccolo/apps/migrations/auto/operations.py b/piccolo/apps/migrations/auto/operations.py index 0676bdbd4..9e6138e38 100644 --- a/piccolo/apps/migrations/auto/operations.py +++ b/piccolo/apps/migrations/auto/operations.py @@ -49,6 +49,7 @@ class AlterColumn: class DropColumn: table_class_name: str column_name: str + column_class: t.Type[Column] db_column_name: str tablename: str schema: t.Optional[str] = None diff --git a/piccolo/apps/migrations/auto/schema_differ.py b/piccolo/apps/migrations/auto/schema_differ.py index 1d095b938..8230db3cf 100644 --- a/piccolo/apps/migrations/auto/schema_differ.py +++ b/piccolo/apps/migrations/auto/schema_differ.py @@ -21,6 +21,7 @@ UniqueGlobalNames, serialise_params, ) +from piccolo.columns.constraints import UniqueConstraint from piccolo.utils.printing import get_fixed_length_string @@ -270,9 +271,13 @@ def check_renamed_columns(self) -> RenameColumnCollection: used_drop_column_names: t.List[str] = [] for add_column in delta.add_columns: + if add_column.column_class == UniqueConstraint: + continue for drop_column in delta.drop_columns: if drop_column.column_name in used_drop_column_names: continue + if drop_column.column_class == UniqueConstraint: + continue user_response = self.auto_input or input( f"Did you rename the `{drop_column.db_column_name}` " # noqa: E501 @@ -508,6 +513,11 @@ def alter_columns(self) -> AlterStatements: ) if alter_column.old_column_class is not None: + if alter_column.old_column_class == UniqueConstraint: + print( + f"You cannot ALTER `{alter_column.column_name}` unique constraint! At first, delete it, then create the new one." # noqa: E501 + ) + continue extra_imports.append( Import( module=alter_column.old_column_class.__module__, @@ -538,6 +548,7 @@ def alter_columns(self) -> AlterStatements: @property def drop_columns(self) -> AlterStatements: response = [] + extra_imports: t.List[Import] = [] for table in self.schema: snapshot_table = self._get_snapshot_table(table.class_name) if snapshot_table: @@ -552,14 +563,29 @@ def drop_columns(self) -> AlterStatements: ): continue + column_class = column.column_class + extra_imports.append( + Import( + module=column_class.__module__, + target=column_class.__name__, + expect_conflict_with_global_name=getattr( + UniqueGlobalNames, + f"COLUMN_{column_class.__name__.upper()}", # noqa: E501 + None, + ), + ) + ) + schema_str = ( "None" if column.schema is None else f'"{column.schema}"' ) response.append( - f"manager.drop_column(table_class_name='{table.class_name}', tablename='{table.tablename}', column_name='{column.column_name}', db_column_name='{column.db_column_name}', schema={schema_str})" # noqa: E501 + f"manager.drop_column(table_class_name='{table.class_name}', tablename='{table.tablename}', column_name='{column.column_name}', db_column_name='{column.db_column_name}', schema={schema_str}, column_class={column_class.__name__})" # noqa: E501 ) - return AlterStatements(statements=response) + return AlterStatements( + statements=response, extra_imports=extra_imports + ) @property def add_columns(self) -> AlterStatements: diff --git a/piccolo/columns/constraints.py b/piccolo/columns/constraints.py new file mode 100644 index 000000000..81e1eed17 --- /dev/null +++ b/piccolo/columns/constraints.py @@ -0,0 +1,41 @@ +import typing as t + +from .base import Column, ColumnMeta + + +class Constraint(Column): + def __init__(self) -> None: + pass + + +class UniqueConstraint(Constraint): + """ + Used for applying unique constraint to multiple columns in the table. + + **Example** + + .. code-block:: python + + class FooTable(Table): + foo_field = Text() + bar_field = Text() + my_constraint_1 = UniqueConstraint(['foo_field', 'bar_field']) + """ + + def __init__(self, unique_columns: t.List[str]) -> None: + if len(unique_columns) < 2: + raise ValueError("unique_columns must contain at least 2 columns") + super().__init__() + self._meta = ColumnMeta() + self.unique_columns = unique_columns + self._meta.params.update({"unique_columns": self.unique_columns}) + + @property + def column_type(self): + return "CONSTRAINT" + + @property + def ddl(self) -> str: + unique_columns_string = ",".join(self.unique_columns) + query = f'{self.column_type} "{self._meta.db_column_name}" UNIQUE ({unique_columns_string})' # noqa: E501 + return query diff --git a/piccolo/query/methods/alter.py b/piccolo/query/methods/alter.py index 040b2f883..beaa07c1c 100644 --- a/piccolo/query/methods/alter.py +++ b/piccolo/query/methods/alter.py @@ -177,6 +177,19 @@ def ddl(self) -> str: return f'ALTER COLUMN "{self.column_name}" TYPE VARCHAR({self.length})' +@dataclass +class AddUniqueConstraint(AlterStatement): + __slots__ = ("constraint_name", "columns") + + constraint_name: str + columns: t.List[str] + + @property + def ddl(self) -> str: + columns_str: str = ",".join(self.columns) + return f"ADD CONSTRAINT {self.constraint_name} UNIQUE ({columns_str})" + + @dataclass class DropConstraint(AlterStatement): __slots__ = ("constraint_name",) @@ -275,6 +288,7 @@ class Alter(DDL): __slots__ = ( "_add_foreign_key_constraint", "_add", + "_add_unique_constraint", "_drop_constraint", "_drop_default", "_drop_table", @@ -294,6 +308,7 @@ def __init__(self, table: t.Type[Table], **kwargs): super().__init__(table, **kwargs) self._add_foreign_key_constraint: t.List[AddForeignKeyConstraint] = [] self._add: t.List[AddColumn] = [] + self._add_unique_constraint: t.List[AddUniqueConstraint] = [] self._drop_constraint: t.List[DropConstraint] = [] self._drop_default: t.List[DropDefault] = [] self._drop_table: t.Optional[DropTable] = None @@ -490,6 +505,16 @@ def _get_constraint_name(self, column: t.Union[str, ForeignKey]) -> str: tablename = self.table._meta.tablename return f"{tablename}_{column_name}_fk" + def add_unique_constraint( + self, constraint_name: str, columns: t.List[str] + ) -> Alter: + self._add_unique_constraint.append( + AddUniqueConstraint( + constraint_name=constraint_name, columns=columns + ) + ) + return self + def drop_constraint(self, constraint_name: str) -> Alter: self._drop_constraint.append( DropConstraint(constraint_name=constraint_name) @@ -590,6 +615,8 @@ def default_ddl(self) -> t.Sequence[str]: self._set_default, self._set_digits, self._set_schema, + self._add_unique_constraint, + self._drop_constraint, ) ] diff --git a/piccolo/table.py b/piccolo/table.py index d6735ac38..5bbd38bcc 100644 --- a/piccolo/table.py +++ b/piccolo/table.py @@ -18,6 +18,7 @@ Secret, Serial, ) +from piccolo.columns.constraints import UniqueConstraint from piccolo.columns.defaults.base import Default from piccolo.columns.indexes import IndexMethod from piccolo.columns.m2m import ( @@ -64,7 +65,6 @@ "reserved keyword. It should still work, but avoid if possible." ) - TABLE_REGISTRY: t.List[t.Type[Table]] = [] @@ -84,6 +84,7 @@ class TableMeta: primary_key: Column = field(default_factory=Column) json_columns: t.List[t.Union[JSON, JSONB]] = field(default_factory=list) secret_columns: t.List[Secret] = field(default_factory=list) + unique_constraints: t.List[UniqueConstraint] = field(default_factory=list) auto_update_columns: t.List[Column] = field(default_factory=list) tags: t.List[str] = field(default_factory=list) help_text: t.Optional[str] = None @@ -276,6 +277,7 @@ def __init_subclass__( auto_update_columns: t.List[Column] = [] primary_key: t.Optional[Column] = None m2m_relationships: t.List[M2M] = [] + unique_constraints: t.List[UniqueConstraint] = [] attribute_names = itertools.chain( *[i.__dict__.keys() for i in reversed(cls.__mro__)] @@ -328,6 +330,9 @@ def __init_subclass__( attribute._meta._table = cls m2m_relationships.append(attribute) + if isinstance(attribute, UniqueConstraint): + unique_constraints.append(attribute) + if not primary_key: primary_key = cls._create_serial_primary_key() setattr(cls, "id", primary_key) @@ -347,6 +352,7 @@ def __init_subclass__( json_columns=json_columns, secret_columns=secret_columns, auto_update_columns=auto_update_columns, + unique_constraints=unique_constraints, tags=tags, help_text=help_text, _db=db, diff --git a/tests/apps/migrations/auto/test_migration_manager.py b/tests/apps/migrations/auto/test_migration_manager.py index a1988a029..aa77fe370 100644 --- a/tests/apps/migrations/auto/test_migration_manager.py +++ b/tests/apps/migrations/auto/test_migration_manager.py @@ -551,6 +551,7 @@ def test_drop_column( table_class_name="Musician", tablename="musician", column_name="name", + column_class=Varchar, ) asyncio.run(manager_2.run()) diff --git a/tests/apps/migrations/auto/test_schema_differ.py b/tests/apps/migrations/auto/test_schema_differ.py index 9cf6d26f2..cb193a3ef 100644 --- a/tests/apps/migrations/auto/test_schema_differ.py +++ b/tests/apps/migrations/auto/test_schema_differ.py @@ -199,7 +199,7 @@ def test_drop_column(self) -> None: self.assertTrue(len(schema_differ.drop_columns.statements) == 1) self.assertEqual( schema_differ.drop_columns.statements[0], - "manager.drop_column(table_class_name='Band', tablename='band', column_name='genre', db_column_name='genre', schema=None)", # noqa: E501 + "manager.drop_column(table_class_name='Band', tablename='band', column_name='genre', db_column_name='genre', schema=None, column_class=Varchar)", # noqa: E501 ) def test_rename_column(self) -> None: @@ -254,7 +254,7 @@ def test_rename_column(self) -> None: self.assertEqual( schema_differ.drop_columns.statements, [ - "manager.drop_column(table_class_name='Band', tablename='band', column_name='title', db_column_name='title', schema=None)" # noqa: E501 + "manager.drop_column(table_class_name='Band', tablename='band', column_name='title', db_column_name='title', schema=None, column_class=Varchar)" # noqa: E501 ], ) self.assertTrue(schema_differ.rename_columns.statements == []) @@ -396,7 +396,7 @@ def mock_input(value: str): self.assertEqual( schema_differ.drop_columns.statements, [ - "manager.drop_column(table_class_name='Band', tablename='band', column_name='b1', db_column_name='b1', schema=None)" # noqa: E501 + "manager.drop_column(table_class_name='Band', tablename='band', column_name='b1', db_column_name='b1', schema=None, column_class=Varchar)" # noqa: E501 ], ) self.assertEqual( diff --git a/tests/apps/migrations/auto/test_schema_snapshot.py b/tests/apps/migrations/auto/test_schema_snapshot.py index 834551f8a..1d46dc94e 100644 --- a/tests/apps/migrations/auto/test_schema_snapshot.py +++ b/tests/apps/migrations/auto/test_schema_snapshot.py @@ -1,6 +1,7 @@ from unittest import TestCase from piccolo.apps.migrations.auto import MigrationManager, SchemaSnapshot +from piccolo.columns import Varchar class TestSchemaSnaphot(TestCase): @@ -140,7 +141,10 @@ def test_drop_column(self): manager_2 = MigrationManager() manager_2.drop_column( - table_class_name="Manager", tablename="manager", column_name="name" + table_class_name="Manager", + tablename="manager", + column_name="name", + column_class=Varchar, ) schema_snapshot = SchemaSnapshot(managers=[manager_1, manager_2]) From 53193024278a240c71facda8de179381bce54bc7 Mon Sep 17 00:00:00 2001 From: Keita Ichihashi Date: Sat, 16 Mar 2024 22:49:33 +0900 Subject: [PATCH 2/4] Fix backwards migration and add tests --- .../apps/migrations/auto/migration_manager.py | 63 ++--- .../migrations/auto/test_migration_manager.py | 98 ++++++++ .../migrations/auto/test_schema_differ.py | 220 ++++++++++++++++++ 3 files changed, 352 insertions(+), 29 deletions(-) diff --git a/piccolo/apps/migrations/auto/migration_manager.py b/piccolo/apps/migrations/auto/migration_manager.py index fb9389358..a3a280073 100644 --- a/piccolo/apps/migrations/auto/migration_manager.py +++ b/piccolo/apps/migrations/auto/migration_manager.py @@ -266,31 +266,19 @@ def add_column( if column_class is None: raise ValueError("Unrecognised column type") - if column_class is UniqueConstraint: - column = column_class(**params) - else: - cleaned_params = deserialise_params(params=params) - column = column_class(**cleaned_params) - + cleaned_params = deserialise_params(params=params) + column = column_class(**cleaned_params) column._meta.name = column_name column._meta.db_column_name = db_column_name - if isinstance(column_class, UniqueConstraint): - self.add_unique_constraints.append( - AddUniqueConstraint( - constraint_name=column_name, - columns=params.get("unique_columns"), # type: ignore - ) - ) - else: - self.add_columns.append( - AddColumnClass( - column=column, - tablename=tablename, - table_class_name=table_class_name, - schema=schema, - ) + self.add_columns.append( + AddColumnClass( + column=column, + tablename=tablename, + table_class_name=table_class_name, + schema=schema, ) + ) def drop_column( self, @@ -663,11 +651,21 @@ async def _run_drop_columns(self, backwards: bool = False): column_to_restore = _Table._meta.get_column_by_name( drop_column.column_name ) - await self._run_query( - _Table.alter().add_column( - name=drop_column.column_name, column=column_to_restore + + if isinstance(column_to_restore, UniqueConstraint): + await self._run_query( + _Table.alter().add_unique_constraint( + constraint_name=column_to_restore._meta.db_column_name, # noqa: E501 + columns=column_to_restore.unique_columns, + ) + ) + else: + await self._run_query( + _Table.alter().add_column( + name=drop_column.column_name, + column=column_to_restore, + ) ) - ) else: for table_class_name in self.drop_columns.table_class_names: columns = self.drop_columns.for_table_class_name( @@ -686,7 +684,7 @@ async def _run_drop_columns(self, backwards: bool = False): ) for column in columns: - if column.column_class == UniqueConstraint: + if column.column_class is UniqueConstraint: await _Table.alter().drop_constraint( constraint_name=column.column_name ) @@ -813,9 +811,16 @@ async def _run_add_columns(self, backwards: bool = False): }, ) - await self._run_query( - _Table.alter().drop_column(add_column.column) - ) + if isinstance(add_column.column, UniqueConstraint): + await self._run_query( + _Table.alter().drop_constraint( + constraint_name=add_column.column._meta.db_column_name, # noqa: E501 + ) + ) + else: + await self._run_query( + _Table.alter().drop_column(add_column.column) + ) else: for table_class_name in self.add_columns.table_class_names: if table_class_name in [i.class_name for i in self.add_tables]: diff --git a/tests/apps/migrations/auto/test_migration_manager.py b/tests/apps/migrations/auto/test_migration_manager.py index aa77fe370..128a64714 100644 --- a/tests/apps/migrations/auto/test_migration_manager.py +++ b/tests/apps/migrations/auto/test_migration_manager.py @@ -10,6 +10,7 @@ from piccolo.columns import Text, Varchar from piccolo.columns.base import OnDelete, OnUpdate from piccolo.columns.column_types import ForeignKey +from piccolo.columns.constraints import UniqueConstraint from piccolo.conf.apps import AppConfig from piccolo.table import Table, sort_table_classes from piccolo.utils.lazy_loader import LazyLoader @@ -336,6 +337,103 @@ def test_add_column(self) -> None: if engine_is("cockroach"): self.assertEqual(response, [{"id": row_id, "name": "Dave"}]) + @engines_only("postgres", "cockroach") + def test_add_unique_constraint(self): + """ + Test adding a unique constraint to a MigrationManager. + """ + manager = MigrationManager() + manager.add_table(class_name="Musician", tablename="musician") + manager.add_column( + table_class_name="Musician", + tablename="musician", + column_name="name", + column_class_name="Varchar", + ) + manager.add_column( + table_class_name="Musician", + tablename="musician", + column_name="label", + column_class_name="Varchar", + ) + manager.add_column( + table_class_name="Musician", + tablename="musician", + column_name="musician_unique", + column_class=UniqueConstraint, + column_class_name="UniqueConstraint", + params={ + "unique_columns": ["name", "label"], + }, + schema=None, + ) + asyncio.run(manager.run()) + + self.run_sync("INSERT INTO musician VALUES (default, 'a', 'a');") + with self.assertRaises(asyncpg.exceptions.UniqueViolationError): + self.run_sync("INSERT INTO musician VALUES (default, 'a', 'a');") + + # Reverse + asyncio.run(manager.run(backwards=True)) + + @engines_only("postgres") + @patch.object( + BaseMigrationManager, "get_migration_managers", new_callable=AsyncMock + ) + @patch.object(BaseMigrationManager, "get_app_config") + def test_drop_unique_constraint( + self, get_app_config: MagicMock, get_migration_managers: MagicMock + ): + """ + Test dropping a unique constraint with a MigrationManager. + Cockroach DB doesn't support dropping unique constraints with ALTER TABLE DROP CONSTRAINT. + https://github.com/cockroachdb/cockroach/issues/42840 + """ # noqa: E501 + manager_1 = MigrationManager() + manager_1.add_table(class_name="Musician", tablename="musician") + manager_1.add_column( + table_class_name="Musician", + tablename="musician", + column_name="name", + column_class_name="Varchar", + ) + manager_1.add_column( + table_class_name="Musician", + tablename="musician", + column_name="label", + column_class_name="Varchar", + ) + manager_1.add_column( + table_class_name="Musician", + tablename="musician", + column_name="musician_unique", + column_class=UniqueConstraint, + column_class_name="UniqueConstraint", + params={ + "unique_columns": ["name", "label"], + }, + ) + asyncio.run(manager_1.run()) + + manager_2 = MigrationManager() + manager_2.drop_column( + table_class_name="Musician", + tablename="musician", + column_name="musician_unique", + column_class=UniqueConstraint, + ) + asyncio.run(manager_2.run()) + get_migration_managers.return_value = [manager_1] + app_config = AppConfig(app_name="music", migrations_folder_path="") + get_app_config.return_value = app_config + asyncio.run(manager_2.run(backwards=True)) + + self.run_sync("INSERT INTO musician VALUES (default, 'a', 'a');") + with self.assertRaises(asyncpg.exceptions.UniqueViolationError): + self.run_sync("INSERT INTO musician VALUES (default, 'a', 'a');") + + asyncio.run(manager_1.run(backwards=True)) + @engines_only("postgres", "cockroach") def test_add_column_with_index(self): """ diff --git a/tests/apps/migrations/auto/test_schema_differ.py b/tests/apps/migrations/auto/test_schema_differ.py index cb193a3ef..ae99118c8 100644 --- a/tests/apps/migrations/auto/test_schema_differ.py +++ b/tests/apps/migrations/auto/test_schema_differ.py @@ -1,6 +1,7 @@ from __future__ import annotations import typing as t +from io import StringIO from unittest import TestCase from unittest.mock import MagicMock, call, patch @@ -13,6 +14,7 @@ SchemaDiffer, ) from piccolo.columns.column_types import Numeric, Varchar +from piccolo.columns.constraints import UniqueConstraint class TestSchemaDiffer(TestCase): @@ -488,6 +490,224 @@ def test_db_column_name(self) -> None: "manager.alter_column(table_class_name='Ticket', tablename='ticket', column_name='price', db_column_name='custom', params={'digits': (4, 2)}, old_params={'digits': (5, 2)}, column_class=Numeric, old_column_class=Numeric, schema=None)", # noqa ) + def test_add_unique_constraint(self) -> None: + """ + Test adding a unique constraint to an existing table. + """ + a_column = Varchar() + a_column._meta.name = "a" + + b_column = Varchar() + b_column._meta.name = "b" + + unique_constraint = UniqueConstraint(["a", "b"]) + unique_constraint._meta.name = "band_unique" + + schema: t.List[DiffableTable] = [ + DiffableTable( + class_name="Band", + tablename="band", + columns=[a_column, b_column, unique_constraint], + ) + ] + schema_snapshot: t.List[DiffableTable] = [ + DiffableTable( + class_name="Band", + tablename="band", + columns=[a_column, b_column], + ) + ] + + schema_differ = SchemaDiffer( + schema=schema, schema_snapshot=schema_snapshot, auto_input="y" + ) + + self.assertTrue(len(schema_differ.add_columns.statements) == 1) + self.assertEqual( + schema_differ.add_columns.statements[0], + "manager.add_column(table_class_name='Band', tablename='band', column_name='band_unique', db_column_name='band_unique', column_class_name='UniqueConstraint', column_class=UniqueConstraint, params={'unique_columns': ['a', 'b']}, schema=None)", # noqa: E501 + ) + + def test_drop_unique_constraint(self) -> None: + """ + Test dropping a unique constraint from an existing table. + """ + a_column = Varchar() + a_column._meta.name = "a" + + b_column = Varchar() + b_column._meta.name = "b" + + unique_constraint = UniqueConstraint(["a", "b"]) + unique_constraint._meta.name = "band_unique" + + schema: t.List[DiffableTable] = [ + DiffableTable( + class_name="Band", + tablename="band", + columns=[a_column, b_column], + ) + ] + schema_snapshot: t.List[DiffableTable] = [ + DiffableTable( + class_name="Band", + tablename="band", + columns=[a_column, b_column, unique_constraint], + ) + ] + + schema_differ = SchemaDiffer( + schema=schema, schema_snapshot=schema_snapshot, auto_input="y" + ) + + self.assertTrue(len(schema_differ.drop_columns.statements) == 1) + self.assertEqual( + schema_differ.drop_columns.statements[0], + "manager.drop_column(table_class_name='Band', tablename='band', column_name='band_unique', db_column_name='band_unique', schema=None, column_class=UniqueConstraint)", # noqa: E501 + ) + + def test_no_diff_alter_unique_constraint(self) -> None: + """ + Make sure returning no diff when attempting + to alter a unique constraint. + """ + a_column = Varchar() + a_column._meta.name = "a" + + b_column = Varchar() + b_column._meta.name = "b" + + c_column = Varchar() + c_column._meta.name = "c" + + unique_constraint_1 = UniqueConstraint(["a", "b"]) + unique_constraint_1._meta.name = "band_unique" + + unique_constraint_2 = UniqueConstraint(["a", "b", "c"]) + unique_constraint_2._meta.name = "band_unique" + + schema: t.List[DiffableTable] = [ + DiffableTable( + class_name="Band", + tablename="band", + columns=[a_column, b_column, c_column, unique_constraint_2], + ) + ] + schema_snapshot: t.List[DiffableTable] = [ + DiffableTable( + class_name="Band", + tablename="band", + columns=[a_column, b_column, c_column, unique_constraint_1], + ) + ] + + with patch("sys.stdout", new_callable=StringIO) as mock_stdout: + schema_differ = SchemaDiffer( + schema=schema, schema_snapshot=schema_snapshot, auto_input="y" + ) + + self.assertTrue(len(schema_differ.add_columns.statements) == 0) + self.assertTrue(len(schema_differ.drop_columns.statements) == 0) + self.assertTrue(len(schema_differ.alter_columns.statements) == 0) + self.assertEqual( + mock_stdout.getvalue().strip(), + "You cannot ALTER `band_unique` unique constraint! At first, delete it, then create the new one.", # noqa: E501 + ) + + def test_add_unique_constraint_skip_rename(self) -> None: + """ + Make sure the column renaming is skipped + when adding a unique constraint. + """ + a_column = Varchar() + a_column._meta.name = "a" + + b_column = Varchar() + b_column._meta.name = "b" + + c_column = Varchar() + c_column._meta.name = "c" + + unique_constraint = UniqueConstraint(["a", "b"]) + unique_constraint._meta.name = "band_unique" + + schema: t.List[DiffableTable] = [ + DiffableTable( + class_name="Band", + tablename="band", + columns=[a_column, b_column, unique_constraint], + ) + ] + schema_snapshot: t.List[DiffableTable] = [ + DiffableTable( + class_name="Band", + tablename="band", + columns=[a_column, b_column, c_column], + ) + ] + + schema_differ = SchemaDiffer( + schema=schema, schema_snapshot=schema_snapshot, auto_input="y" + ) + + self.assertTrue(len(schema_differ.drop_columns.statements) == 1) + self.assertEqual( + schema_differ.drop_columns.statements[0], + "manager.drop_column(table_class_name='Band', tablename='band', column_name='c', db_column_name='c', schema=None, column_class=Varchar)", # noqa: E501 + ) + self.assertTrue(len(schema_differ.add_columns.statements) == 1) + self.assertEqual( + schema_differ.add_columns.statements[0], + "manager.add_column(table_class_name='Band', tablename='band', column_name='band_unique', db_column_name='band_unique', column_class_name='UniqueConstraint', column_class=UniqueConstraint, params={'unique_columns': ['a', 'b']}, schema=None)", # noqa: E501 + ) + + def test_drop_unique_constraint_skip_rename(self) -> None: + """ + Make sure the column renaming is skipped + when dropping a unique constraint. + """ + a_column = Varchar() + a_column._meta.name = "a" + + b_column = Varchar() + b_column._meta.name = "b" + + c_column = Varchar() + c_column._meta.name = "c" + + unique_constraint = UniqueConstraint(["a", "b"]) + unique_constraint._meta.name = "band_unique" + + schema: t.List[DiffableTable] = [ + DiffableTable( + class_name="Band", + tablename="band", + columns=[a_column, b_column, c_column], + ) + ] + schema_snapshot: t.List[DiffableTable] = [ + DiffableTable( + class_name="Band", + tablename="band", + columns=[a_column, b_column, unique_constraint], + ) + ] + + schema_differ = SchemaDiffer( + schema=schema, schema_snapshot=schema_snapshot, auto_input="y" + ) + + self.assertTrue(len(schema_differ.drop_columns.statements) == 1) + self.assertEqual( + schema_differ.drop_columns.statements[0], + "manager.drop_column(table_class_name='Band', tablename='band', column_name='band_unique', db_column_name='band_unique', schema=None, column_class=UniqueConstraint)", # noqa: E501 + ) + self.assertTrue(len(schema_differ.add_columns.statements) == 1) + self.assertEqual( + schema_differ.add_columns.statements[0], + "manager.add_column(table_class_name='Band', tablename='band', column_name='c', db_column_name='c', column_class_name='Varchar', column_class=Varchar, params={'length': 255, 'default': '', 'null': False, 'primary_key': False, 'unique': False, 'index': False, 'index_method': IndexMethod.btree, 'choices': None, 'db_column_name': None, 'secret': False}, schema=None)", # noqa: E501 + ) + def test_alter_default(self): pass From 3de933787994677e41339cde104b8f7669cbaec4 Mon Sep 17 00:00:00 2001 From: Keita Ichihashi Date: Sun, 17 Mar 2024 16:30:27 +0900 Subject: [PATCH 3/4] Update --- piccolo/apps/migrations/auto/schema_differ.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/piccolo/apps/migrations/auto/schema_differ.py b/piccolo/apps/migrations/auto/schema_differ.py index 8230db3cf..09599efee 100644 --- a/piccolo/apps/migrations/auto/schema_differ.py +++ b/piccolo/apps/migrations/auto/schema_differ.py @@ -268,15 +268,19 @@ def check_renamed_columns(self) -> RenameColumnCollection: # We track which dropped columns have already been identified by # the user as renames, so we don't ask them if another column # was also renamed from it. + + # When adding or removing a unique constraint, + # we don't ask and rename it. used_drop_column_names: t.List[str] = [] for add_column in delta.add_columns: - if add_column.column_class == UniqueConstraint: + if add_column.column_class is UniqueConstraint: continue for drop_column in delta.drop_columns: - if drop_column.column_name in used_drop_column_names: - continue - if drop_column.column_class == UniqueConstraint: + if ( + drop_column.column_name in used_drop_column_names + or drop_column.column_class is UniqueConstraint + ): continue user_response = self.auto_input or input( @@ -513,7 +517,7 @@ def alter_columns(self) -> AlterStatements: ) if alter_column.old_column_class is not None: - if alter_column.old_column_class == UniqueConstraint: + if alter_column.old_column_class is UniqueConstraint: print( f"You cannot ALTER `{alter_column.column_name}` unique constraint! At first, delete it, then create the new one." # noqa: E501 ) From 0a9ea5b2f9ab842923fc9430e2d2986fb423220c Mon Sep 17 00:00:00 2001 From: Keita Ichihashi Date: Fri, 22 Mar 2024 23:00:06 +0900 Subject: [PATCH 4/4] Fix tests --- .../migrations/auto/test_migration_manager.py | 60 ++++++++++++++++++- 1 file changed, 58 insertions(+), 2 deletions(-) diff --git a/tests/apps/migrations/auto/test_migration_manager.py b/tests/apps/migrations/auto/test_migration_manager.py index 128a64714..459b68d48 100644 --- a/tests/apps/migrations/auto/test_migration_manager.py +++ b/tests/apps/migrations/auto/test_migration_manager.py @@ -338,9 +338,9 @@ def test_add_column(self) -> None: self.assertEqual(response, [{"id": row_id, "name": "Dave"}]) @engines_only("postgres", "cockroach") - def test_add_unique_constraint(self): + def test_add_table_with_unique_constraint(self): """ - Test adding a unique constraint to a MigrationManager. + Test adding a table with a unique constraint to a MigrationManager. """ manager = MigrationManager() manager.add_table(class_name="Musician", tablename="musician") @@ -376,6 +376,62 @@ def test_add_unique_constraint(self): # Reverse asyncio.run(manager.run(backwards=True)) + @engines_only("postgres") + @patch.object( + BaseMigrationManager, "get_migration_managers", new_callable=AsyncMock + ) + @patch.object(BaseMigrationManager, "get_app_config") + def test_add_unique_constraint( + self, get_app_config: MagicMock, get_migration_managers: MagicMock + ): + """ + Test adding a unique constraint to a MigrationManager. + Cockroach DB doesn't support dropping unique constraints with ALTER TABLE DROP CONSTRAINT. + https://github.com/cockroachdb/cockroach/issues/42840 + """ # noqa: E501 + manager_1 = MigrationManager() + manager_1.add_table(class_name="Musician", tablename="musician") + manager_1.add_column( + table_class_name="Musician", + tablename="musician", + column_name="name", + column_class_name="Varchar", + ) + manager_1.add_column( + table_class_name="Musician", + tablename="musician", + column_name="label", + column_class_name="Varchar", + ) + asyncio.run(manager_1.run()) + + manager_2 = MigrationManager() + manager_2.add_column( + table_class_name="Musician", + tablename="musician", + column_name="musician_unique", + column_class=UniqueConstraint, + column_class_name="UniqueConstraint", + params={ + "unique_columns": ["name", "label"], + }, + schema=None, + ) + asyncio.run(manager_2.run()) + + self.run_sync("INSERT INTO musician VALUES (default, 'a', 'a');") + with self.assertRaises(asyncpg.exceptions.UniqueViolationError): + self.run_sync("INSERT INTO musician VALUES (default, 'a', 'a');") + + get_migration_managers.return_value = [manager_1] + app_config = AppConfig(app_name="music", migrations_folder_path="") + get_app_config.return_value = app_config + asyncio.run(manager_2.run(backwards=True)) + + self.run_sync("INSERT INTO musician VALUES (default, 'a', 'a');") + + asyncio.run(manager_1.run(backwards=True)) + @engines_only("postgres") @patch.object( BaseMigrationManager, "get_migration_managers", new_callable=AsyncMock