diff --git a/CHANGES.md b/CHANGES.md index 037d278f..2542d704 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -20,6 +20,10 @@ when the repository does not exist. With previous versions of CrateDB, it was `RepositoryUnknownException`. +- Contrib: Add a few SQLAlchemy patches and polyfills, which do not fit well + into the vanilla Python driver / SQLAlchemy dialect. + + ## 2023/06/27 0.0.0 - Import "data retention" implementation from . diff --git a/cratedb_toolkit/sqlalchemy/__init__.py b/cratedb_toolkit/sqlalchemy/__init__.py new file mode 100644 index 00000000..ebff58bb --- /dev/null +++ b/cratedb_toolkit/sqlalchemy/__init__.py @@ -0,0 +1,2 @@ +from .patch import patch_inspector +from .polyfill import check_uniqueness_factory, polyfill_autoincrement, polyfill_refresh_after_dml, refresh_table diff --git a/cratedb_toolkit/sqlalchemy/patch.py b/cratedb_toolkit/sqlalchemy/patch.py new file mode 100644 index 00000000..78c89770 --- /dev/null +++ b/cratedb_toolkit/sqlalchemy/patch.py @@ -0,0 +1,34 @@ +import typing as t + +import sqlalchemy as sa + + +def patch_inspector(): + """ + When using `get_table_names()`, make sure the correct schema name gets used. + + Apparently, SQLAlchemy does not honor the `search_path` of the engine, when + using the inspector? + + FIXME: Bug in CrateDB SQLAlchemy dialect? + """ + + def get_effective_schema(engine: sa.Engine): + schema_name_raw = engine.url.query.get("schema") + schema_name = None + if isinstance(schema_name_raw, str): + schema_name = schema_name_raw + elif isinstance(schema_name_raw, tuple): + schema_name = schema_name_raw[0] + return schema_name + + from crate.client.sqlalchemy.dialect import CrateDialect + + get_table_names_dist = CrateDialect.get_table_names + + def get_table_names(self, connection: sa.Connection, schema: t.Optional[str] = None, **kw: t.Any) -> t.List[str]: + if schema is None: + schema = get_effective_schema(connection.engine) + return get_table_names_dist(self, connection=connection, schema=schema, **kw) + + CrateDialect.get_table_names = get_table_names # type: ignore diff --git a/cratedb_toolkit/sqlalchemy/polyfill.py b/cratedb_toolkit/sqlalchemy/polyfill.py new file mode 100644 index 00000000..079cf4cd --- /dev/null +++ b/cratedb_toolkit/sqlalchemy/polyfill.py @@ -0,0 +1,108 @@ +import itertools + +import sqlalchemy as sa +from sqlalchemy.event import listen + + +def polyfill_autoincrement(): + """ + Configure SQLAlchemy model columns with an alternative to `autoincrement=True`. + + In this case, use a random identifier: Nagamani19, a short, unique, + non-sequential identifier based on Hashids. + + TODO: Submit patch to `crate-python`, to be enabled by a + dialect parameter `crate_polyfill_autoincrement` or such. + """ + import sqlalchemy.sql.schema as schema + from sqlalchemy import func + + init_dist = schema.Column.__init__ + + def __init__(self, *args, **kwargs): + if "autoincrement" in kwargs: + del kwargs["autoincrement"] + if "default" not in kwargs: + kwargs["default"] = func.now() + init_dist(self, *args, **kwargs) + + schema.Column.__init__ = __init__ # type: ignore[method-assign] + + +def check_uniqueness_factory(sa_entity, attribute_name): + """ + Run a manual column value uniqueness check on a table, and raise an IntegrityError if applicable. + + CrateDB does not support the UNIQUE constraint on columns. This attempts to emulate it. + + TODO: Submit patch to `crate-python`, to be enabled by a + dialect parameter `crate_polyfill_unique` or such. + """ + + def check_uniqueness(mapper, connection, target): + from sqlalchemy.exc import IntegrityError + + if isinstance(target, sa_entity): + # TODO: How to use `session.query(SqlExperiment)` here? + stmt = ( + mapper.selectable.select() + .filter(getattr(sa_entity, attribute_name) == getattr(target, attribute_name)) + .compile(bind=connection.engine) + ) + results = connection.execute(stmt) + if results.rowcount > 0: + raise IntegrityError( + statement=stmt, + params=[], + orig=Exception(f"DuplicateKeyException on column: {target.__tablename__}.{attribute_name}"), + ) + + return check_uniqueness + + +def polyfill_refresh_after_dml(session): + """ + Run `REFRESH TABLE ` after each INSERT, UPDATE, and DELETE operation. + + CrateDB is eventually consistent, i.e. write operations are not flushed to + disk immediately, so readers may see stale data. In a traditional OLTP-like + application, this is not applicable. + + This SQLAlchemy extension makes sure that data is synchronized after each + operation manipulating data. + + > `after_{insert,update,delete}` events only apply to the session flush operation + > and do not apply to the ORM DML operations described at ORM-Enabled INSERT, + > UPDATE, and DELETE statements. To intercept ORM DML events, use + > `SessionEvents.do_orm_execute().` + > -- https://docs.sqlalchemy.org/en/20/orm/events.html#sqlalchemy.orm.MapperEvents.after_insert + + > Intercept statement executions that occur on behalf of an ORM Session object. + > -- https://docs.sqlalchemy.org/en/20/orm/events.html#sqlalchemy.orm.SessionEvents.do_orm_execute + + > Execute after flush has completed, but before commit has been called. + > -- https://docs.sqlalchemy.org/en/20/orm/events.html#sqlalchemy.orm.SessionEvents.after_flush + + TODO: Submit patch to `crate-python`, to be enabled by a + dialect parameter `crate_dml_refresh` or such. + """ # noqa: E501 + listen(session, "after_flush", do_flush) + + +def do_flush(session, flush_context): + """ + SQLAlchemy event handler for the 'after_flush' event, + invoking `REFRESH TABLE` on each table which has been modified. + """ + dirty_entities = itertools.chain(session.new, session.dirty, session.deleted) + dirty_classes = {entity.__class__ for entity in dirty_entities} + for class_ in dirty_classes: + refresh_table(session, class_) + + +def refresh_table(connection, target): + """ + Invoke a `REFRESH TABLE` statement. + """ + sql = f"REFRESH TABLE {target.__tablename__}" + connection.execute(sa.text(sql)) diff --git a/pyproject.toml b/pyproject.toml index d605eb84..12910646 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -88,7 +88,7 @@ dependencies = [ "colorama<1", "colorlog", "crash", - "crate[sqlalchemy]", + "crate[sqlalchemy]>=0.34", "sqlalchemy>=2", ] [project.optional-dependencies] @@ -216,6 +216,7 @@ extend-exclude = [ "tests/*" = ["S101"] # Allow use of `assert`, and `print`. "examples/*" = ["T201"] # Allow `print` "cratedb_toolkit/retention/cli.py" = ["T201"] # Allow `print` +"cratedb_toolkit/sqlalchemy/__init__.py" = ["F401"] # Allow `moduleĀ“ imported but unused [tool.setuptools.packages.find] namespaces = false diff --git a/tests/sqlalchemy/__init__.py b/tests/sqlalchemy/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/sqlalchemy/conftest.py b/tests/sqlalchemy/conftest.py new file mode 100644 index 00000000..58b36ed5 --- /dev/null +++ b/tests/sqlalchemy/conftest.py @@ -0,0 +1,13 @@ +import pytest + +from cratedb_toolkit.util import DatabaseAdapter +from tests.conftest import TESTDRIVE_DATA_SCHEMA + + +@pytest.fixture +def database(cratedb): + """ + Provide a client database adapter, which is connected to the test database instance. + """ + database_url = cratedb.get_connection_url() + "?schema=" + TESTDRIVE_DATA_SCHEMA + yield DatabaseAdapter(dburi=database_url) diff --git a/tests/sqlalchemy/test_patch.py b/tests/sqlalchemy/test_patch.py new file mode 100644 index 00000000..1ed51552 --- /dev/null +++ b/tests/sqlalchemy/test_patch.py @@ -0,0 +1,42 @@ +import sqlalchemy as sa + +from cratedb_toolkit.sqlalchemy import patch_inspector +from tests.conftest import TESTDRIVE_DATA_SCHEMA + + +def test_inspector_vanilla(database): + """ + Vanilla SQLAlchemy Inspector tests. + """ + tablename = f'"{TESTDRIVE_DATA_SCHEMA}"."foobar"' + inspector: sa.Inspector = sa.inspect(database.engine) + database.run_sql(f"CREATE TABLE {tablename} AS SELECT 1") + + assert inspector.has_schema(TESTDRIVE_DATA_SCHEMA) is True + + table_names = inspector.get_table_names(schema=TESTDRIVE_DATA_SCHEMA) + assert table_names == ["foobar"] + + view_names = inspector.get_view_names(schema=TESTDRIVE_DATA_SCHEMA) + assert view_names == [] + + indexes = inspector.get_indexes(tablename) + assert indexes == [] + + +def test_inspector_patched(database): + """ + Patched SQLAlchemy Inspector tests. + + Both MLflow and LangChain invoke `get_table_names()` without a `schema` argument. + This verifies that it still works, when it properly has been assigned to + the `?schema=` connection string URL parameter. + """ + patch_inspector() + tablename = f'"{TESTDRIVE_DATA_SCHEMA}"."foobar"' + inspector: sa.Inspector = sa.inspect(database.engine) + database.run_sql(f"CREATE TABLE {tablename} AS SELECT 1") + assert inspector.has_schema(TESTDRIVE_DATA_SCHEMA) is True + + table_names = inspector.get_table_names() + assert table_names == ["foobar"] diff --git a/tests/sqlalchemy/test_polyfill.py b/tests/sqlalchemy/test_polyfill.py new file mode 100644 index 00000000..8599f4cf --- /dev/null +++ b/tests/sqlalchemy/test_polyfill.py @@ -0,0 +1,109 @@ +import re + +import pytest +import sqlalchemy as sa + +from cratedb_toolkit.sqlalchemy import check_uniqueness_factory, polyfill_autoincrement, polyfill_refresh_after_dml + + +def get_autoincrement_model(): + """ + Provide a minimal SQLAlchemy model including an AUTOINCREMENT primary key. + """ + Base = sa.orm.declarative_base() + + class FooBar(Base): + """ + Minimal SQLAlchemy model with autoincrement primary key. + """ + + __tablename__ = "foobar" + identifier = sa.Column(sa.BigInteger, primary_key=True, autoincrement=True) + foo = sa.Column(sa.String) + + return FooBar + + +def get_unique_model(): + """ + Provide a minimal SQLAlchemy model including a column with UNIQUE constraint. + """ + Base = sa.orm.declarative_base() + + class FooBar(Base): + """ + Minimal SQLAlchemy model with UNIQUE constraint. + """ + + __tablename__ = "foobar" + identifier = sa.Column(sa.BigInteger, primary_key=True, default=sa.func.now()) + name = sa.Column(sa.String, unique=True, nullable=False) + + return FooBar + + +def test_autoincrement_vanilla(database): + """ + When using a model including an autoincrement column, and not assigning a value, CrateDB will fail. + """ + FooBar = get_autoincrement_model() + FooBar.metadata.create_all(database.engine) + with sa.orm.Session(database.engine) as session: + session.add(FooBar(foo="bar")) + with pytest.raises(sa.exc.ProgrammingError) as ex: + session.commit() + assert ex.match( + re.escape("SQLParseException[Column `identifier` is required but is missing from the insert statement]") + ) + + +def test_autoincrement_polyfill(database): + """ + When using a model including an autoincrement column, and the corresponding polyfill + is installed, the procedure will succeed. + """ + polyfill_autoincrement() + + FooBar = get_autoincrement_model() + FooBar.metadata.create_all(database.engine) + with sa.orm.Session(database.engine) as session: + session.add(FooBar(foo="bar")) + session.commit() + + +def test_unique_patched(database): + """ + When using a model including a column with UNIQUE constraint, the SQLAlchemy dialect will ignore it. + """ + FooBar = get_unique_model() + FooBar.metadata.create_all(database.engine) + + with sa.orm.Session(database.engine) as session: + session.add(FooBar(name="name-1")) + session.commit() + session.add(FooBar(name="name-1")) + session.commit() + + +def test_unique_patched_and_active(database): + """ + When using a model including a column with UNIQUE constraint, enabling the patch, + and activating the uniqueness check, SQLAlchemy will raise `DuplicateKeyException` + errors if uniqueness constraints don't hold. + """ + FooBar = get_unique_model() + FooBar.metadata.create_all(database.engine) + + # For uniqueness checks to take place, installing an event handler is needed. + # TODO: Maybe add to some helper function? + # TODO: Maybe derive from the model definition itself? + sa.event.listen(FooBar, "before_insert", check_uniqueness_factory(FooBar, "name")) + + with sa.orm.Session(database.engine) as session: + polyfill_refresh_after_dml(session) + session.add(FooBar(name="name-1")) + session.commit() + session.add(FooBar(name="name-1")) + with pytest.raises(sa.exc.IntegrityError) as ex: + session.commit() + assert ex.match("DuplicateKeyException on column: foobar.name")