Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support composite unique constraints in auto migration #957

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions piccolo/apps/migrations/auto/diffable_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,7 @@ def __sub__(self, value: DiffableTable) -> TableDelta:
db_column_name=i.column._meta.db_column_name,
tablename=value.tablename,
schema=self.schema,
column_class=i.column.__class__,
)
for i in sorted(
{ColumnComparison(column=column) for column in value.columns}
Expand Down
74 changes: 58 additions & 16 deletions piccolo/apps/migrations/auto/migration_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,11 @@
from piccolo.apps.migrations.auto.serialisation import deserialise_params
from piccolo.columns import Column, column_types
from piccolo.columns.column_types import ForeignKey, Serial
from piccolo.columns.constraints import UniqueConstraint
from piccolo.engine import engine_finder
from piccolo.query import Query
from piccolo.query.base import DDL
from piccolo.query.methods.alter import AddUniqueConstraint, DropConstraint
from piccolo.schema import SchemaDDLBase
from piccolo.table import Table, create_table_class, sort_table_classes
from piccolo.utils.warnings import colored_warning
Expand Down Expand Up @@ -159,6 +161,12 @@ class MigrationManager:
alter_columns: AlterColumnCollection = field(
default_factory=AlterColumnCollection
)
add_unique_constraints: t.List[AddUniqueConstraint] = field(
default_factory=list
)
drop_unique_constraints: t.List[DropConstraint] = field(
default_factory=list
)
raw: t.List[t.Union[t.Callable, AsyncFunction]] = field(
default_factory=list
)
Expand Down Expand Up @@ -277,13 +285,15 @@ def drop_column(
table_class_name: str,
tablename: str,
column_name: str,
column_class: t.Type[Column],
db_column_name: t.Optional[str] = None,
schema: t.Optional[str] = None,
):
self.drop_columns.append(
DropColumn(
table_class_name=table_class_name,
column_name=column_name,
column_class=column_class,
db_column_name=db_column_name or column_name,
tablename=tablename,
schema=schema,
Expand Down Expand Up @@ -641,11 +651,21 @@ async def _run_drop_columns(self, backwards: bool = False):
column_to_restore = _Table._meta.get_column_by_name(
drop_column.column_name
)
await self._run_query(
_Table.alter().add_column(
name=drop_column.column_name, column=column_to_restore

if isinstance(column_to_restore, UniqueConstraint):
await self._run_query(
_Table.alter().add_unique_constraint(
constraint_name=column_to_restore._meta.db_column_name, # noqa: E501
columns=column_to_restore.unique_columns,
)
)
else:
await self._run_query(
_Table.alter().add_column(
name=drop_column.column_name,
column=column_to_restore,
)
)
)
else:
for table_class_name in self.drop_columns.table_class_names:
columns = self.drop_columns.for_table_class_name(
Expand All @@ -664,9 +684,16 @@ async def _run_drop_columns(self, backwards: bool = False):
)

for column in columns:
await self._run_query(
_Table.alter().drop_column(column=column.column_name)
)
if column.column_class is UniqueConstraint:
await _Table.alter().drop_constraint(
constraint_name=column.column_name
)
else:
await self._run_query(
_Table.alter().drop_column(
column=column.column_name
)
)

async def _run_rename_tables(self, backwards: bool = False):
for rename_table in self.rename_tables:
Expand Down Expand Up @@ -784,9 +811,16 @@ async def _run_add_columns(self, backwards: bool = False):
},
)

await self._run_query(
_Table.alter().drop_column(add_column.column)
)
if isinstance(add_column.column, UniqueConstraint):
await self._run_query(
_Table.alter().drop_constraint(
constraint_name=add_column.column._meta.db_column_name, # noqa: E501
)
)
else:
await self._run_query(
_Table.alter().drop_column(add_column.column)
)
else:
for table_class_name in self.add_columns.table_class_names:
if table_class_name in [i.class_name for i in self.add_tables]:
Expand Down Expand Up @@ -862,15 +896,23 @@ async def _run_add_columns(self, backwards: bool = False):
add_column.column._meta.name
)

await self._run_query(
_Table.alter().add_column(
name=column._meta.name, column=column
if isinstance(column, UniqueConstraint):
await self._run_query(
_Table.alter().add_unique_constraint(
constraint_name=column._meta.db_column_name,
columns=column.unique_columns,
)
)
)
if add_column.column._meta.index:
else:
await self._run_query(
_Table.create_index([add_column.column])
_Table.alter().add_column(
name=column._meta.name, column=column
)
)
if add_column.column._meta.index:
await self._run_query(
_Table.create_index([add_column.column])
)

async def _run_change_table_schema(self, backwards: bool = False):
from piccolo.schema import SchemaManager
Expand Down
1 change: 1 addition & 0 deletions piccolo/apps/migrations/auto/operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ class AlterColumn:
class DropColumn:
table_class_name: str
column_name: str
column_class: t.Type[Column]
db_column_name: str
tablename: str
schema: t.Optional[str] = None
Expand Down
36 changes: 33 additions & 3 deletions piccolo/apps/migrations/auto/schema_differ.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
UniqueGlobalNames,
serialise_params,
)
from piccolo.columns.constraints import UniqueConstraint
from piccolo.utils.printing import get_fixed_length_string


Expand Down Expand Up @@ -267,11 +268,19 @@ def check_renamed_columns(self) -> RenameColumnCollection:
# We track which dropped columns have already been identified by
# the user as renames, so we don't ask them if another column
# was also renamed from it.

# When adding or removing a unique constraint,
# we don't ask and rename it.
used_drop_column_names: t.List[str] = []

for add_column in delta.add_columns:
if add_column.column_class is UniqueConstraint:
continue
for drop_column in delta.drop_columns:
if drop_column.column_name in used_drop_column_names:
if (
drop_column.column_name in used_drop_column_names
or drop_column.column_class is UniqueConstraint
):
continue

user_response = self.auto_input or input(
Expand Down Expand Up @@ -508,6 +517,11 @@ def alter_columns(self) -> AlterStatements:
)

if alter_column.old_column_class is not None:
if alter_column.old_column_class is UniqueConstraint:
print(
f"You cannot ALTER `{alter_column.column_name}` unique constraint! At first, delete it, then create the new one." # noqa: E501
)
continue
extra_imports.append(
Import(
module=alter_column.old_column_class.__module__,
Expand Down Expand Up @@ -538,6 +552,7 @@ def alter_columns(self) -> AlterStatements:
@property
def drop_columns(self) -> AlterStatements:
response = []
extra_imports: t.List[Import] = []
for table in self.schema:
snapshot_table = self._get_snapshot_table(table.class_name)
if snapshot_table:
Expand All @@ -552,14 +567,29 @@ def drop_columns(self) -> AlterStatements:
):
continue

column_class = column.column_class
extra_imports.append(
Import(
module=column_class.__module__,
target=column_class.__name__,
expect_conflict_with_global_name=getattr(
UniqueGlobalNames,
f"COLUMN_{column_class.__name__.upper()}", # noqa: E501
None,
),
)
)

schema_str = (
"None" if column.schema is None else f'"{column.schema}"'
)

response.append(
f"manager.drop_column(table_class_name='{table.class_name}', tablename='{table.tablename}', column_name='{column.column_name}', db_column_name='{column.db_column_name}', schema={schema_str})" # noqa: E501
f"manager.drop_column(table_class_name='{table.class_name}', tablename='{table.tablename}', column_name='{column.column_name}', db_column_name='{column.db_column_name}', schema={schema_str}, column_class={column_class.__name__})" # noqa: E501
)
return AlterStatements(statements=response)
return AlterStatements(
statements=response, extra_imports=extra_imports
)

@property
def add_columns(self) -> AlterStatements:
Expand Down
41 changes: 41 additions & 0 deletions piccolo/columns/constraints.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import typing as t

from .base import Column, ColumnMeta


class Constraint(Column):
def __init__(self) -> None:
pass


class UniqueConstraint(Constraint):
"""
Used for applying unique constraint to multiple columns in the table.

**Example**

.. code-block:: python

class FooTable(Table):
foo_field = Text()
bar_field = Text()
my_constraint_1 = UniqueConstraint(['foo_field', 'bar_field'])
"""

def __init__(self, unique_columns: t.List[str]) -> None:
if len(unique_columns) < 2:
raise ValueError("unique_columns must contain at least 2 columns")
super().__init__()
self._meta = ColumnMeta()
self.unique_columns = unique_columns
self._meta.params.update({"unique_columns": self.unique_columns})

@property
def column_type(self):
return "CONSTRAINT"

@property
def ddl(self) -> str:
unique_columns_string = ",".join(self.unique_columns)
query = f'{self.column_type} "{self._meta.db_column_name}" UNIQUE ({unique_columns_string})' # noqa: E501
return query
27 changes: 27 additions & 0 deletions piccolo/query/methods/alter.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,19 @@ def ddl(self) -> str:
return f'ALTER COLUMN "{self.column_name}" TYPE VARCHAR({self.length})'


@dataclass
class AddUniqueConstraint(AlterStatement):
__slots__ = ("constraint_name", "columns")

constraint_name: str
columns: t.List[str]

@property
def ddl(self) -> str:
columns_str: str = ",".join(self.columns)
return f"ADD CONSTRAINT {self.constraint_name} UNIQUE ({columns_str})"


@dataclass
class DropConstraint(AlterStatement):
__slots__ = ("constraint_name",)
Expand Down Expand Up @@ -275,6 +288,7 @@ class Alter(DDL):
__slots__ = (
"_add_foreign_key_constraint",
"_add",
"_add_unique_constraint",
"_drop_constraint",
"_drop_default",
"_drop_table",
Expand All @@ -294,6 +308,7 @@ def __init__(self, table: t.Type[Table], **kwargs):
super().__init__(table, **kwargs)
self._add_foreign_key_constraint: t.List[AddForeignKeyConstraint] = []
self._add: t.List[AddColumn] = []
self._add_unique_constraint: t.List[AddUniqueConstraint] = []
self._drop_constraint: t.List[DropConstraint] = []
self._drop_default: t.List[DropDefault] = []
self._drop_table: t.Optional[DropTable] = None
Expand Down Expand Up @@ -490,6 +505,16 @@ def _get_constraint_name(self, column: t.Union[str, ForeignKey]) -> str:
tablename = self.table._meta.tablename
return f"{tablename}_{column_name}_fk"

def add_unique_constraint(
self, constraint_name: str, columns: t.List[str]
) -> Alter:
self._add_unique_constraint.append(
AddUniqueConstraint(
constraint_name=constraint_name, columns=columns
)
)
return self

def drop_constraint(self, constraint_name: str) -> Alter:
self._drop_constraint.append(
DropConstraint(constraint_name=constraint_name)
Expand Down Expand Up @@ -590,6 +615,8 @@ def default_ddl(self) -> t.Sequence[str]:
self._set_default,
self._set_digits,
self._set_schema,
self._add_unique_constraint,
self._drop_constraint,
)
]

Expand Down
8 changes: 7 additions & 1 deletion piccolo/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
Secret,
Serial,
)
from piccolo.columns.constraints import UniqueConstraint
from piccolo.columns.defaults.base import Default
from piccolo.columns.indexes import IndexMethod
from piccolo.columns.m2m import (
Expand Down Expand Up @@ -64,7 +65,6 @@
"reserved keyword. It should still work, but avoid if possible."
)


TABLE_REGISTRY: t.List[t.Type[Table]] = []


Expand All @@ -84,6 +84,7 @@ class TableMeta:
primary_key: Column = field(default_factory=Column)
json_columns: t.List[t.Union[JSON, JSONB]] = field(default_factory=list)
secret_columns: t.List[Secret] = field(default_factory=list)
unique_constraints: t.List[UniqueConstraint] = field(default_factory=list)
auto_update_columns: t.List[Column] = field(default_factory=list)
tags: t.List[str] = field(default_factory=list)
help_text: t.Optional[str] = None
Expand Down Expand Up @@ -276,6 +277,7 @@ def __init_subclass__(
auto_update_columns: t.List[Column] = []
primary_key: t.Optional[Column] = None
m2m_relationships: t.List[M2M] = []
unique_constraints: t.List[UniqueConstraint] = []

attribute_names = itertools.chain(
*[i.__dict__.keys() for i in reversed(cls.__mro__)]
Expand Down Expand Up @@ -328,6 +330,9 @@ def __init_subclass__(
attribute._meta._table = cls
m2m_relationships.append(attribute)

if isinstance(attribute, UniqueConstraint):
unique_constraints.append(attribute)

if not primary_key:
primary_key = cls._create_serial_primary_key()
setattr(cls, "id", primary_key)
Expand All @@ -347,6 +352,7 @@ def __init_subclass__(
json_columns=json_columns,
secret_columns=secret_columns,
auto_update_columns=auto_update_columns,
unique_constraints=unique_constraints,
tags=tags,
help_text=help_text,
_db=db,
Expand Down
Loading
Loading