diff --git a/CHANGES.md b/CHANGES.md index ed4e3991..1a61d8af 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -8,6 +8,9 @@ - Add support for Python 3.12 +- SQLAlchemy: Improve UNIQUE constraints polyfill to accept multiple + column names, for emulating unique composite keys. + ## 2023/10/10 0.0.1 diff --git a/cratedb_toolkit/sqlalchemy/polyfill.py b/cratedb_toolkit/sqlalchemy/polyfill.py index 079cf4cd..e8b3abfc 100644 --- a/cratedb_toolkit/sqlalchemy/polyfill.py +++ b/cratedb_toolkit/sqlalchemy/polyfill.py @@ -29,7 +29,7 @@ def __init__(self, *args, **kwargs): schema.Column.__init__ = __init__ # type: ignore[method-assign] -def check_uniqueness_factory(sa_entity, attribute_name): +def check_uniqueness_factory(sa_entity, *attribute_names): """ Run a manual column value uniqueness check on a table, and raise an IntegrityError if applicable. @@ -39,22 +39,27 @@ def check_uniqueness_factory(sa_entity, attribute_name): dialect parameter `crate_polyfill_unique` or such. """ + # Synthesize a canonical "name" for the constraint, + # composed of all column names involved. + constraint_name: str = "-".join(attribute_names) + def check_uniqueness(mapper, connection, target): from sqlalchemy.exc import IntegrityError if isinstance(target, sa_entity): # TODO: How to use `session.query(SqlExperiment)` here? - stmt = ( - mapper.selectable.select() - .filter(getattr(sa_entity, attribute_name) == getattr(target, attribute_name)) - .compile(bind=connection.engine) - ) + stmt = mapper.selectable.select() + for attribute_name in attribute_names: + stmt = stmt.filter(getattr(sa_entity, attribute_name) == getattr(target, attribute_name)) + stmt = stmt.compile(bind=connection.engine) results = connection.execute(stmt) if results.rowcount > 0: raise IntegrityError( statement=stmt, params=[], - orig=Exception(f"DuplicateKeyException on column: {target.__tablename__}.{attribute_name}"), + orig=Exception( + f"DuplicateKeyException in table '{target.__tablename__}' " f"on constraint '{constraint_name}'" + ), ) return check_uniqueness diff --git a/tests/conftest.py b/tests/conftest.py index f1c99799..44c280b7 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -18,6 +18,8 @@ f'"{TESTDRIVE_DATA_SCHEMA}"."sensor_readings"', f'"{TESTDRIVE_DATA_SCHEMA}"."testdrive"', f'"{TESTDRIVE_DATA_SCHEMA}"."foobar"', + f'"{TESTDRIVE_DATA_SCHEMA}"."foobar_unique_single"', + f'"{TESTDRIVE_DATA_SCHEMA}"."foobar_unique_composite"', ] diff --git a/tests/sqlalchemy/test_polyfill.py b/tests/sqlalchemy/test_polyfill.py index 8599f4cf..74ae6ef7 100644 --- a/tests/sqlalchemy/test_polyfill.py +++ b/tests/sqlalchemy/test_polyfill.py @@ -24,22 +24,42 @@ class FooBar(Base): return FooBar -def get_unique_model(): +def get_unique_model_single(): """ Provide a minimal SQLAlchemy model including a column with UNIQUE constraint. """ Base = sa.orm.declarative_base() - class FooBar(Base): + class FooBarSingle(Base): """ Minimal SQLAlchemy model with UNIQUE constraint. """ - __tablename__ = "foobar" + __tablename__ = "foobar_unique_single" identifier = sa.Column(sa.BigInteger, primary_key=True, default=sa.func.now()) name = sa.Column(sa.String, unique=True, nullable=False) - return FooBar + return FooBarSingle + + +def get_unique_model_composite(): + """ + Provide a minimal SQLAlchemy model using a composite UNIQUE constraint. + """ + Base = sa.orm.declarative_base() + + class FooBarComposite(Base): + """ + Minimal SQLAlchemy model with UNIQUE constraint. + """ + + __tablename__ = "foobar_unique_composite" + identifier = sa.Column(sa.BigInteger, primary_key=True, default=sa.func.now()) + name = sa.Column(sa.String, nullable=False) + user_id = sa.Column(sa.Integer, nullable=False) + __table_args__ = (sa.UniqueConstraint("name", "user_id", name="unique_name_user"),) + + return FooBarComposite def test_autoincrement_vanilla(database): @@ -75,7 +95,7 @@ def test_unique_patched(database): """ When using a model including a column with UNIQUE constraint, the SQLAlchemy dialect will ignore it. """ - FooBar = get_unique_model() + FooBar = get_unique_model_single() FooBar.metadata.create_all(database.engine) with sa.orm.Session(database.engine) as session: @@ -85,13 +105,13 @@ def test_unique_patched(database): session.commit() -def test_unique_patched_and_active(database): +def test_unique_patched_and_active_single(database): """ When using a model including a column with UNIQUE constraint, enabling the patch, and activating the uniqueness check, SQLAlchemy will raise `DuplicateKeyException` errors if uniqueness constraints don't hold. """ - FooBar = get_unique_model() + FooBar = get_unique_model_single() FooBar.metadata.create_all(database.engine) # For uniqueness checks to take place, installing an event handler is needed. @@ -106,4 +126,26 @@ def test_unique_patched_and_active(database): session.add(FooBar(name="name-1")) with pytest.raises(sa.exc.IntegrityError) as ex: session.commit() - assert ex.match("DuplicateKeyException on column: foobar.name") + assert ex.match("DuplicateKeyException in table 'foobar_unique_single' on constraint 'name'") + + +def test_unique_patched_and_active_composite(database): + """ + Similar to the _single variant, verify emulated **composite** UNIQUE constraints. + """ + FooBar = get_unique_model_composite() + FooBar.metadata.create_all(database.engine) + + # For uniqueness checks to take place, installing an event handler is needed. + # TODO: Maybe add to some helper function? + # TODO: Maybe derive from the model definition itself? + sa.event.listen(FooBar, "before_insert", check_uniqueness_factory(FooBar, "name", "user_id")) + + with sa.orm.Session(database.engine) as session: + polyfill_refresh_after_dml(session) + session.add(FooBar(name="name-1", user_id=1)) + session.commit() + session.add(FooBar(name="name-1", user_id=1)) + with pytest.raises(sa.exc.IntegrityError) as ex: + session.commit() + assert ex.match("DuplicateKeyException in table 'foobar_unique_composite' on constraint 'name-user_id'")