Skip to content

Commit

Permalink
refactor: optimize m2m add logic (#1620)
Browse files Browse the repository at this point in the history
* refactor: optimize m2m add logic

* Share logic with clear and remove for many2many relation

* Add create_unique_index argument to ManyToManyField

* Upgrade dependencies

* Use getenv instead

* fixing mysql schema generate error

* Success to pass mysql test

* Upgrade coveralls

* Update deps and changelog

* Upgrade coveralls
  • Loading branch information
waketzheng authored May 25, 2024
1 parent 7deda88 commit 8f83a75
Show file tree
Hide file tree
Showing 9 changed files with 528 additions and 479 deletions.
6 changes: 6 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,12 @@ Changelog
0.21
====

0.21.2
------
Added
^^^^^
- Add `create_unique_index` argument to M2M field and default if it is true (#1620)

0.21.1
------
Fixed
Expand Down
2 changes: 1 addition & 1 deletion conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,6 @@ def initialize_tests(request):
except ImportError:
pass

db_url = os.environ.get("TORTOISE_TEST_DB", "sqlite://:memory:")
db_url = os.getenv("TORTOISE_TEST_DB", "sqlite://:memory:")
initializer(["tests.testmodels"], db_url=db_url)
request.addfinalizer(finalizer)
752 changes: 390 additions & 362 deletions poetry.lock

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "tortoise-orm"
version = "0.21.1"
version = "0.21.2"
description = "Easy async ORM for python, built with relations in mind"
authors = ["Andrey Bondar <[email protected]>", "Nickolas Grigoriadis <[email protected]>", "long2ice <[email protected]>"]
license = "Apache-2.0"
Expand Down
2 changes: 1 addition & 1 deletion tests/fields/test_db_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
class TestIndexAlias(test.TestCase):
Field: Any = fields.IntField

def test_index_alias(self):
def test_index_alias(self) -> None:
kwargs: dict = getattr(self, "init_kwargs", {})
with self.assertWarnsRegex(
DeprecationWarning, "`index` is deprecated, please use `db_index` instead"
Expand Down
32 changes: 27 additions & 5 deletions tests/schema/test_generate_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,14 +79,17 @@ class TestGenerateSchema(test.SimpleTestCase):
"backward_sts" INT NOT NULL REFERENCES "sometable" ("sometable_id") ON DELETE CASCADE,
"sts_forward" INT NOT NULL REFERENCES "sometable" ("sometable_id") ON DELETE CASCADE
);
CREATE UNIQUE INDEX IF NOT EXISTS "uidx_sometable_s_backwar_fc8fc8" ON "sometable_self" ("backward_sts", "sts_forward");
CREATE TABLE IF NOT EXISTS "team_team" (
"team_rel_id" VARCHAR(50) NOT NULL REFERENCES "team" ("name") ON DELETE CASCADE,
"team_id" VARCHAR(50) NOT NULL REFERENCES "team" ("name") ON DELETE CASCADE
);
CREATE UNIQUE INDEX IF NOT EXISTS "uidx_team_team_team_re_d994df" ON "team_team" ("team_rel_id", "team_id");
CREATE TABLE IF NOT EXISTS "teamevents" (
"event_id" BIGINT NOT NULL REFERENCES "event" ("id") ON DELETE SET NULL,
"team_id" VARCHAR(50) NOT NULL REFERENCES "team" ("name") ON DELETE SET NULL
) /* How participants relate */;
CREATE UNIQUE INDEX IF NOT EXISTS "uidx_teamevents_event_i_664dbc" ON "teamevents" ("event_id", "team_id");
""".strip()

async def asyncSetUp(self):
Expand Down Expand Up @@ -250,10 +253,12 @@ async def test_schema_no_db_constraint(self):
"team_rel_id" VARCHAR(50) NOT NULL,
"team_id" VARCHAR(50) NOT NULL
);
CREATE UNIQUE INDEX "uidx_team_team_team_re_d994df" ON "team_team" ("team_rel_id", "team_id");
CREATE TABLE "teamevents" (
"event_id" BIGINT NOT NULL,
"team_id" VARCHAR(50) NOT NULL
) /* How participants relate */;""",
) /* How participants relate */;
CREATE UNIQUE INDEX "uidx_teamevents_event_i_664dbc" ON "teamevents" ("event_id", "team_id");""",
)

async def test_schema(self):
Expand Down Expand Up @@ -332,14 +337,17 @@ async def test_schema(self):
"backward_sts" INT NOT NULL REFERENCES "sometable" ("sometable_id") ON DELETE CASCADE,
"sts_forward" INT NOT NULL REFERENCES "sometable" ("sometable_id") ON DELETE CASCADE
);
CREATE UNIQUE INDEX "uidx_sometable_s_backwar_fc8fc8" ON "sometable_self" ("backward_sts", "sts_forward");
CREATE TABLE "team_team" (
"team_rel_id" VARCHAR(50) NOT NULL REFERENCES "team" ("name") ON DELETE CASCADE,
"team_id" VARCHAR(50) NOT NULL REFERENCES "team" ("name") ON DELETE CASCADE
);
CREATE UNIQUE INDEX "uidx_team_team_team_re_d994df" ON "team_team" ("team_rel_id", "team_id");
CREATE TABLE "teamevents" (
"event_id" BIGINT NOT NULL REFERENCES "event" ("id") ON DELETE SET NULL,
"team_id" VARCHAR(50) NOT NULL REFERENCES "team" ("name") ON DELETE SET NULL
) /* How participants relate */;
CREATE UNIQUE INDEX "uidx_teamevents_event_i_664dbc" ON "teamevents" ("event_id", "team_id");
""".strip(),
)

Expand Down Expand Up @@ -389,7 +397,9 @@ async def test_m2m_no_auto_create(self):
CREATE TABLE "team_team" (
"team_rel_id" VARCHAR(50) NOT NULL REFERENCES "team" ("name") ON DELETE CASCADE,
"team_id" VARCHAR(50) NOT NULL REFERENCES "team" ("name") ON DELETE CASCADE
);""".strip(),
);
CREATE UNIQUE INDEX "uidx_team_team_team_re_d994df" ON "team_team" ("team_rel_id", "team_id");
""".strip(),
)


Expand Down Expand Up @@ -772,7 +782,9 @@ async def test_m2m_no_auto_create(self):
`team_id` VARCHAR(50) NOT NULL,
FOREIGN KEY (`team_rel_id`) REFERENCES `team` (`name`) ON DELETE CASCADE,
FOREIGN KEY (`team_id`) REFERENCES `team` (`name`) ON DELETE CASCADE
) CHARACTER SET utf8mb4;""".strip(),
) CHARACTER SET utf8mb4;
CREATE UNIQUE INDEX IF NOT EXISTS "uidx_team_team_team_re_d994df" ON "team_team" ("team_rel_id", "team_id");
""".strip(),
)


Expand Down Expand Up @@ -842,11 +854,13 @@ async def test_schema_no_db_constraint(self):
"team_rel_id" VARCHAR(50) NOT NULL,
"team_id" VARCHAR(50) NOT NULL
);
CREATE UNIQUE INDEX "uidx_team_team_team_re_d994df" ON "team_team" ("team_rel_id", "team_id");
CREATE TABLE "teamevents" (
"event_id" BIGINT NOT NULL,
"team_id" VARCHAR(50) NOT NULL
);
COMMENT ON TABLE "teamevents" IS 'How participants relate';""",
COMMENT ON TABLE "teamevents" IS 'How participants relate';
CREATE UNIQUE INDEX "uidx_teamevents_event_i_664dbc" ON "teamevents" ("event_id", "team_id");""",
)

async def test_schema(self):
Expand Down Expand Up @@ -939,15 +953,18 @@ async def test_schema(self):
"backward_sts" INT NOT NULL REFERENCES "sometable" ("sometable_id") ON DELETE CASCADE,
"sts_forward" INT NOT NULL REFERENCES "sometable" ("sometable_id") ON DELETE CASCADE
);
CREATE UNIQUE INDEX "uidx_sometable_s_backwar_fc8fc8" ON "sometable_self" ("backward_sts", "sts_forward");
CREATE TABLE "team_team" (
"team_rel_id" VARCHAR(50) NOT NULL REFERENCES "team" ("name") ON DELETE CASCADE,
"team_id" VARCHAR(50) NOT NULL REFERENCES "team" ("name") ON DELETE CASCADE
);
CREATE UNIQUE INDEX "uidx_team_team_team_re_d994df" ON "team_team" ("team_rel_id", "team_id");
CREATE TABLE "teamevents" (
"event_id" BIGINT NOT NULL REFERENCES "event" ("id") ON DELETE SET NULL,
"team_id" VARCHAR(50) NOT NULL REFERENCES "team" ("name") ON DELETE SET NULL
);
COMMENT ON TABLE "teamevents" IS 'How participants relate';
CREATE UNIQUE INDEX "uidx_teamevents_event_i_664dbc" ON "teamevents" ("event_id", "team_id");
""".strip(),
)

Expand Down Expand Up @@ -1041,15 +1058,18 @@ async def test_schema_safe(self):
"backward_sts" INT NOT NULL REFERENCES "sometable" ("sometable_id") ON DELETE CASCADE,
"sts_forward" INT NOT NULL REFERENCES "sometable" ("sometable_id") ON DELETE CASCADE
);
CREATE UNIQUE INDEX IF NOT EXISTS "uidx_sometable_s_backwar_fc8fc8" ON "sometable_self" ("backward_sts", "sts_forward");
CREATE TABLE IF NOT EXISTS "team_team" (
"team_rel_id" VARCHAR(50) NOT NULL REFERENCES "team" ("name") ON DELETE CASCADE,
"team_id" VARCHAR(50) NOT NULL REFERENCES "team" ("name") ON DELETE CASCADE
);
CREATE UNIQUE INDEX IF NOT EXISTS "uidx_team_team_team_re_d994df" ON "team_team" ("team_rel_id", "team_id");
CREATE TABLE IF NOT EXISTS "teamevents" (
"event_id" BIGINT NOT NULL REFERENCES "event" ("id") ON DELETE SET NULL,
"team_id" VARCHAR(50) NOT NULL REFERENCES "team" ("name") ON DELETE SET NULL
);
COMMENT ON TABLE "teamevents" IS 'How participants relate';
CREATE UNIQUE INDEX IF NOT EXISTS "uidx_teamevents_event_i_664dbc" ON "teamevents" ("event_id", "team_id");
""".strip(),
)

Expand Down Expand Up @@ -1151,7 +1171,9 @@ async def test_m2m_no_auto_create(self):
CREATE TABLE "team_team" (
"team_rel_id" VARCHAR(50) NOT NULL REFERENCES "team" ("name") ON DELETE CASCADE,
"team_id" VARCHAR(50) NOT NULL REFERENCES "team" ("name") ON DELETE CASCADE
);""".strip(),
);
CREATE UNIQUE INDEX "uidx_team_team_team_re_d994df" ON "team_team" ("team_rel_id", "team_id");
""".strip(),
)


Expand Down
84 changes: 53 additions & 31 deletions tortoise/backends/base/schema_generator.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import re
from hashlib import sha256
from typing import TYPE_CHECKING, Any, List, Set, Type, cast
from typing import TYPE_CHECKING, Any, List, Set, Type, Union, cast

from tortoise.exceptions import ConfigurationError
from tortoise.fields import JSONField, TextField, UUIDField
Expand All @@ -19,6 +20,7 @@ class BaseSchemaGenerator:
TABLE_CREATE_TEMPLATE = 'CREATE TABLE {exists}"{table_name}" ({fields}){extra}{comment};'
FIELD_TEMPLATE = '"{name}" {type} {nullable} {unique}{primary}{default}{comment}'
INDEX_CREATE_TEMPLATE = 'CREATE INDEX {exists}"{index_name}" ON "{table_name}" ({fields});'
UNIQUE_INDEX_CREATE_TEMPLATE = INDEX_CREATE_TEMPLATE.replace(" INDEX", " UNIQUE INDEX")
UNIQUE_CONSTRAINT_CREATE_TEMPLATE = 'CONSTRAINT "{index_name}" UNIQUE ({fields})'
GENERATED_PK_TEMPLATE = '"{field_name}" {generated_sql}{comment}'
FK_TEMPLATE = ' REFERENCES "{table}" ("{field}") ON DELETE {on_delete}{comment}'
Expand Down Expand Up @@ -134,12 +136,12 @@ def _make_hash(*args: str, length: int) -> str:
return sha256(";".join(args).encode("utf-8")).hexdigest()[:length]

def _generate_index_name(
self, prefix: str, model: "Type[Model]", field_names: List[str]
self, prefix: str, model: "Union[Type[Model], str]", field_names: List[str]
) -> str:
# NOTE: for compatibility, index name should not be longer than 30
# characters (Oracle limit).
# That's why we slice some of the strings here.
table_name = model._meta.db_table
table_name = model if isinstance(model, str) else model._meta.db_table
index_name = "{}_{}_{}_{}".format(
prefix,
table_name[:11],
Expand Down Expand Up @@ -169,6 +171,15 @@ def _get_index_sql(self, model: "Type[Model]", field_names: List[str], safe: boo
fields=", ".join([self.quote(f) for f in field_names]),
)

def _get_unique_index_sql(self, exists: str, table_name: str, field_names: List[str]) -> str:
index_name = self._generate_index_name("uidx", table_name, field_names)
return self.UNIQUE_INDEX_CREATE_TEMPLATE.format(
exists=exists,
index_name=index_name,
table_name=table_name,
fields=", ".join([self.quote(f) for f in field_names]),
)

def _get_unique_constraint_sql(self, model: "Type[Model]", field_names: List[str]) -> str:
return self.UNIQUE_CONSTRAINT_CREATE_TEMPLATE.format(
index_name=self._generate_index_name("uid", model, field_names),
Expand Down Expand Up @@ -348,36 +359,35 @@ def _get_table_sql(self, model: "Type[Model]", safe: bool = True) -> dict:
field_object = cast("ManyToManyFieldInstance", model._meta.fields_map[m2m_field])
if field_object._generated or field_object.through in models_tables:
continue
backward_key, forward_key = field_object.backward_key, field_object.forward_key
backward_fk = forward_fk = ""
if field_object.db_constraint:
backward_fk = self._create_fk_string(
"",
backward_key,
model._meta.db_table,
model._meta.db_pk_column,
field_object.on_delete,
"",
)
forward_fk = self._create_fk_string(
"",
forward_key,
field_object.related_model._meta.db_table,
field_object.related_model._meta.db_pk_column,
field_object.on_delete,
"",
)
exists = "IF NOT EXISTS " if safe else ""
table_name = field_object.through
m2m_create_string = self.M2M_TABLE_TEMPLATE.format(
exists="IF NOT EXISTS " if safe else "",
table_name=field_object.through,
backward_fk=(
self._create_fk_string(
"",
field_object.backward_key,
model._meta.db_table,
model._meta.db_pk_column,
field_object.on_delete,
"",
)
if field_object.db_constraint
else ""
),
forward_fk=(
self._create_fk_string(
"",
field_object.forward_key,
field_object.related_model._meta.db_table,
field_object.related_model._meta.db_pk_column,
field_object.on_delete,
"",
)
if field_object.db_constraint
else ""
),
backward_key=field_object.backward_key,
exists=exists,
table_name=table_name,
backward_fk=backward_fk,
forward_fk=forward_fk,
backward_key=backward_key,
backward_type=model._meta.pk.get_for_dialect(self.DIALECT, "SQL_TYPE"),
forward_key=field_object.forward_key,
forward_key=forward_key,
forward_type=field_object.related_model._meta.pk.get_for_dialect(
self.DIALECT, "SQL_TYPE"
),
Expand All @@ -398,6 +408,18 @@ def _get_table_sql(self, model: "Type[Model]", safe: bool = True) -> dict:
"",
) # may have better way
m2m_create_string += self._post_table_hook()
if field_object.create_unique_index:
unique_index_create_sql = self._get_unique_index_sql(
exists, table_name, [backward_key, forward_key]
)
if unique_index_create_sql.endswith(";"):
m2m_create_string += "\n" + unique_index_create_sql
else:
lines = m2m_create_string.splitlines()
lines[-2] += ","
indent = m.group() if (m := re.match(r"\s+", lines[-2])) else ""
lines.insert(-1, indent + unique_index_create_sql)
m2m_create_string = "\n".join(lines)
m2m_tables_for_create.append(m2m_create_string)

return {
Expand Down
12 changes: 3 additions & 9 deletions tortoise/backends/mysql/schema_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ class MySQLSchemaGenerator(BaseSchemaGenerator):
TABLE_CREATE_TEMPLATE = "CREATE TABLE {exists}`{table_name}` ({fields}){extra}{comment};"
INDEX_CREATE_TEMPLATE = "KEY `{index_name}` ({fields})"
UNIQUE_CONSTRAINT_CREATE_TEMPLATE = "UNIQUE KEY `{index_name}` ({fields})"
UNIQUE_INDEX_CREATE_TEMPLATE = UNIQUE_CONSTRAINT_CREATE_TEMPLATE
FIELD_TEMPLATE = "`{name}` {type} {nullable} {unique}{primary}{comment}{default}"
GENERATED_PK_TEMPLATE = "`{field_name}` {generated_sql}{comment}"
FK_TEMPLATE = (
Expand Down Expand Up @@ -69,15 +70,8 @@ def _escape_default_value(self, default: Any):

def _get_index_sql(self, model: "Type[Model]", field_names: List[str], safe: bool) -> str:
"""Get index SQLs, but keep them for ourselves"""
self._field_indexes.append(
self.INDEX_CREATE_TEMPLATE.format(
exists="IF NOT EXISTS " if safe else "",
index_name=self._generate_index_name("idx", model, field_names),
table_name=model._meta.db_table,
fields=", ".join(self.quote(f) for f in field_names),
)
)

index_create_sql = super()._get_index_sql(model, field_names, safe)
self._field_indexes.append(index_create_sql)
return ""

def _create_fk_string(
Expand Down
Loading

0 comments on commit 8f83a75

Please sign in to comment.