Skip to content

Commit

Permalink
Rely on Field.db_default for field additions and removals.
Browse files Browse the repository at this point in the history
  • Loading branch information
charettes committed Nov 5, 2023
1 parent 24847b0 commit f64189d
Show file tree
Hide file tree
Showing 5 changed files with 193 additions and 41 deletions.
22 changes: 13 additions & 9 deletions syzygy/autodetector.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,12 @@
from .constants import Stage
from .exceptions import AmbiguousStage
from .operations import (
AddField,
AlterField,
PostAddField,
PreRemoveField,
RenameField,
RenameModel,
get_pre_field_addition_operation,
get_pre_remove_field_operation,
)
from .plan import partition_operations

Expand Down Expand Up @@ -167,14 +167,16 @@ def _generate_added_field(self, app_label, model_name, field_name):
super()._generate_added_field(app_label, model_name, field_name)
old_add_field = self.generated_operations[app_label][-1]
field = old_add_field.field
if field.null and not field.has_default():
if (field.null and not field.has_default()) or getattr(
field, "db_default", NOT_PROVIDED
) is not NOT_PROVIDED:
return
# ... and immediately swap the added operation by an adjsuted one.
add_field = AddField(
# ... otherwise swap the added operation by an adjusted one.
add_field = get_pre_field_addition_operation(
old_add_field.model_name,
old_add_field.name,
field,
old_add_field.preserve_default,
old_add_field.field,
preserve_default=old_add_field.preserve_default,
)
add_field._auto_deps = old_add_field._auto_deps
self.generated_operations[app_label][-1] = add_field
Expand All @@ -198,7 +200,9 @@ def _generate_added_field(self, app_label, model_name, field_name):
def _generate_removed_field(self, app_label, model_name, field_name):
field = self.from_state.models[app_label, model_name].fields[field_name]
remove_default = field.default
if remove_default is NOT_PROVIDED and field.null:
if (remove_default is NOT_PROVIDED and field.null) or getattr(
field, "db_default", NOT_PROVIDED
) is not NOT_PROVIDED:
return super()._generate_removed_field(app_label, model_name, field_name)

if remove_default is NOT_PROVIDED:
Expand Down Expand Up @@ -232,7 +236,7 @@ def _generate_removed_field(self, app_label, model_name, field_name):
field.null = True
else:
field.default = remove_default
pre_remove_field = PreRemoveField(
pre_remove_field = get_pre_remove_field_operation(
model_name=model_name, name=field_name, field=field
)
self.add_operation(app_label, pre_remove_field)
Expand Down
3 changes: 3 additions & 0 deletions syzygy/compat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
import django

field_db_default_supported = django.VERSION >= (5,)
51 changes: 51 additions & 0 deletions syzygy/operations.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
from contextlib import contextmanager

import django
from django.db.migrations import operations
from django.db.models.fields import NOT_PROVIDED
from django.utils.functional import cached_property

from .compat import field_db_default_supported
from .constants import Stage


Expand Down Expand Up @@ -122,6 +125,37 @@ def describe(self):
return "Set field %s of %s NULLable" % (self.name, self.model_name)


if field_db_default_supported:
# XXX: This allows for a more descriptive migration_name_fragment
# to be associated with instances of AlterField.
operations.AlterField.migration_name_fragment = cached_property(
operations.AlterField.migration_name_fragment.fget
)

def get_pre_remove_field_operation(model_name, name, field, **kwargs):
if field.db_default is not NOT_PROVIDED:
raise ValueError(
"Fields with a db_default don't require a pre-deployment operation."
)
field = field.clone()
if field.has_default():
field.db_default = field.get_default()
fragment = f"set_db_default_{model_name.lower()}_{name}"
description = f"Set database DEFAULT of field {name} on {model_name}"
else:
field.null = True
fragment = f"set_nullable_{model_name.lower()}_{name}"
description = f"Set field {name} of {model_name} NULLable"
operation = operations.AlterField(model_name, name, field, **kwargs)
operation.migration_name_fragment = fragment
operation.describe = lambda: description
return operation

PreRemoveField = get_pre_remove_field_operation
else:
get_pre_remove_field_operation = PreRemoveField


class AddField(operations.AddField):
"""
Subclass of `AddField` that preserves the database default on database
Expand Down Expand Up @@ -167,6 +201,23 @@ def database_forwards(self, app_label, schema_editor, from_state, to_state):
)


if field_db_default_supported:

def get_pre_field_addition_operation(model_name, name, field, **kwargs):
if field.db_default is not NOT_PROVIDED:
raise ValueError(
"Fields with a db_default don't require a pre-deployment operation."
)
field = field.clone()
field.db_default = field.get_default()
operation = operations.AddField(model_name, name, field, **kwargs)
return operation

AddField = get_pre_field_addition_operation
else:
get_pre_field_addition_operation = AddField


class PostAddField(operations.AlterField):
"""
Elidable operation that drops a previously preserved database default.
Expand Down
51 changes: 44 additions & 7 deletions tests/test_autodetector.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import List, Optional
from unittest import mock
from unittest import mock, skipUnless

import django
from django.core.management.color import color_style
from django.db import migrations, models
from django.db.migrations.questioner import (
Expand All @@ -12,6 +13,7 @@
from django.test.utils import captured_stderr, captured_stdin, captured_stdout

from syzygy.autodetector import MigrationAutodetector
from syzygy.compat import field_db_default_supported
from syzygy.constants import Stage
from syzygy.exceptions import AmbiguousStage
from syzygy.operations import (
Expand Down Expand Up @@ -53,23 +55,25 @@ def get_changes(
class AutodetectorTests(AutodetectorTestCase):
def _test_field_addition(self, field):
from_model = ModelState("tests", "Model", [])
to_model = ModelState(
"tests", "Model", [("field", models.IntegerField(default=42))]
)
to_model = ModelState("tests", "Model", [("field", field)])
changes = self.get_changes([from_model], [to_model])["tests"]
self.assertEqual(len(changes), 2)
self.assertEqual(get_migration_stage(changes[0]), Stage.PRE_DEPLOY)
self.assertEqual(changes[0].dependencies, [])
self.assertEqual(len(changes[0].operations), 1)
self.assertIsInstance(changes[0].operations[0], AddField)
pre_operation = changes[0].operations[0]
if field_db_default_supported:
self.assertIsInstance(pre_operation, migrations.AddField)
self.assertEqual(pre_operation.field.db_default, field.default)
else:
self.assertIsInstance(pre_operation, AddField)
self.assertEqual(get_migration_stage(changes[1]), Stage.POST_DEPLOY)
self.assertEqual(changes[1].dependencies, [("tests", "auto_1")])
self.assertEqual(len(changes[1].operations), 1)
self.assertIsInstance(changes[1].operations[0], PostAddField)

def test_field_addition(self):
fields = [
models.IntegerField(),
models.IntegerField(default=42),
models.IntegerField(null=True, default=42),
]
Expand All @@ -90,6 +94,19 @@ def test_nullable_field_addition(self):
self.assertEqual(len(changes), 1)
self.assertEqual(get_migration_stage(changes[0]), Stage.PRE_DEPLOY)

@skipUnless(field_db_default_supported, "Field.db_default is not supported")
def test_db_default_field_addition(self):
"""
No action required if the field already has a `db_default`
"""
from_model = ModelState("tests", "Model", [])
to_model = ModelState(
"tests", "Model", [("field", models.IntegerField(db_default=42))]
)
changes = self.get_changes([from_model], [to_model])["tests"]
self.assertEqual(len(changes), 1)
self.assertEqual(get_migration_stage(changes[0]), Stage.PRE_DEPLOY)

def _test_field_removal(self, field):
from_model = ModelState("tests", "Model", [("field", field)])
to_model = ModelState("tests", "Model", [])
Expand All @@ -99,7 +116,14 @@ def _test_field_removal(self, field):
self.assertEqual(changes[0].dependencies, [])
self.assertEqual(len(changes[0].operations), 1)
pre_operation = changes[0].operations[0]
self.assertIsInstance(pre_operation, PreRemoveField)
if field_db_default_supported:
self.assertIsInstance(pre_operation, migrations.AlterField)
if field.has_default():
self.assertEqual(pre_operation.field.db_default, 42)
else:
self.assertIs(pre_operation.field.null, True)
else:
self.assertIsInstance(pre_operation, PreRemoveField)
if not field.has_default():
self.assertIs(pre_operation.field.null, True)
self.assertEqual(get_migration_stage(changes[1]), Stage.POST_DEPLOY)
Expand Down Expand Up @@ -130,6 +154,19 @@ def test_nullable_field_removal(self):
self.assertEqual(len(changes), 1)
self.assertEqual(get_migration_stage(changes[0]), Stage.POST_DEPLOY)

@skipUnless(field_db_default_supported, "Field.db_default is not supported")
def test_db_default_field_removal(self):
"""
No action required if the field already has a `db_default`
"""
from_model = ModelState(
"tests", "Model", [("field", models.IntegerField(db_default=42))]
)
to_model = ModelState("tests", "Model", [])
changes = self.get_changes([from_model], [to_model])["tests"]
self.assertEqual(len(changes), 1)
self.assertEqual(get_migration_stage(changes[0]), Stage.POST_DEPLOY)

def test_non_nullable_field_removal_default(self):
from_model = ModelState("tests", "Model", [("field", models.IntegerField())])
to_model = ModelState("tests", "Model", [])
Expand Down
107 changes: 82 additions & 25 deletions tests/test_operations.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import List, Optional, Tuple
from unittest import mock, skipUnless

from django.db import connection, migrations, models
from django.db.migrations.operations.base import Operation
Expand All @@ -9,8 +10,14 @@
from django.test import TestCase

from syzygy.autodetector import MigrationAutodetector
from syzygy.compat import field_db_default_supported
from syzygy.constants import Stage
from syzygy.operations import AddField, PostAddField, PreRemoveField
from syzygy.operations import (
AddField,
PostAddField,
PreRemoveField,
get_pre_remove_field_operation,
)
from syzygy.plan import get_operation_stage


Expand Down Expand Up @@ -82,18 +89,38 @@ def test_deconstruct(self):
field_name = "foo"
field = models.IntegerField(default=42)
operation = AddField(model_name, field_name, field)
self.assertEqual(
operation.deconstruct(),
(
"AddField",
[],
{"model_name": model_name, "name": field_name, "field": field},
),
)
serializer = OperationSerializer(operation)
serialized, imports = serializer.serialize()
self.assertTrue(serialized.startswith("syzygy.operations.AddField"))
self.assertIn("import syzygy.operations", imports)
deconstructed = operation.deconstruct()
if field_db_default_supported:
self.assertEqual(
operation.deconstruct(),
(
"AddField",
[],
{"model_name": model_name, "name": field_name, "field": mock.ANY},
),
)
self.assertEqual(
deconstructed[2]["field"].deconstruct(),
(
None,
"django.db.models.IntegerField",
[],
{"default": 42, "db_default": 42},
),
)
else:
self.assertEqual(
deconstructed,
(
"AddField",
[],
{"model_name": model_name, "name": field_name, "field": field},
),
)
serializer = OperationSerializer(operation)
serialized, imports = serializer.serialize()
self.assertTrue(serialized.startswith("syzygy.operations.AddField"))
self.assertIn("import syzygy.operations", imports)


class PostAddFieldTests(OperationTestCase):
Expand Down Expand Up @@ -287,18 +314,38 @@ def test_deconstruct(self):
field_name,
field,
)
self.assertEqual(
operation.deconstruct(),
(
"PreRemoveField",
[],
{"model_name": model_name, "name": field_name, "field": field},
),
)
serializer = OperationSerializer(operation)
serialized, imports = serializer.serialize()
self.assertTrue(serialized.startswith("syzygy.operations.PreRemoveField"))
self.assertIn("import syzygy.operations", imports)
deconstructed = operation.deconstruct()
if field_db_default_supported:
self.assertEqual(
deconstructed,
(
"AlterField",
[],
{"model_name": model_name, "name": field_name, "field": mock.ANY},
),
)
self.assertEqual(
deconstructed[2]["field"].deconstruct(),
(
None,
"django.db.models.IntegerField",
[],
{"default": 42, "db_default": 42},
),
)
else:
self.assertEqual(
desconstructed,
(
"PreRemoveField",
[],
{"model_name": model_name, "name": field_name, "field": field},
),
)
serializer = OperationSerializer(operation)
serialized, imports = serializer.serialize()
self.assertTrue(serialized.startswith("syzygy.operations.PreRemoveField"))
self.assertIn("import syzygy.operations", imports)

def test_elidable(self):
model_name = "TestModel"
Expand All @@ -313,3 +360,13 @@ def test_elidable(self):
migrations.RemoveField(model_name, field_name, field),
]
self.assert_optimizes_to(operations, [operations[-1]])

@skipUnless(field_db_default_supported, "Field.db_default not supported")
def test_get_pre_remove_field_operation(self):
with self.assertRaisesMessage(
ValueError,
"Fields with a db_default don't require a pre-deployment operation.",
):
get_pre_remove_field_operation(
"model", "field", models.IntegerField(db_default=42)
)

0 comments on commit f64189d

Please sign in to comment.