Skip to content

Commit

Permalink
Contrib: Add a few SQLAlchemy patches and polyfills
Browse files Browse the repository at this point in the history
They do not fit well into the vanilla Python driver / SQLAlchemy
dialect.
  • Loading branch information
amotl committed Oct 10, 2023
1 parent fd4b7be commit 3cae25f
Show file tree
Hide file tree
Showing 9 changed files with 314 additions and 1 deletion.
4 changes: 4 additions & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,10 @@
when the repository does not exist. With previous versions of CrateDB, it was
`RepositoryUnknownException`.

- Contrib: Add a few SQLAlchemy patches and polyfills, which do not fit well
into the vanilla Python driver / SQLAlchemy dialect.


## 2023/06/27 0.0.0

- Import "data retention" implementation from <https://github.com/crate/crate-airflow-tutorial>.
Expand Down
2 changes: 2 additions & 0 deletions cratedb_toolkit/sqlalchemy/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from .patch import patch_inspector
from .polyfill import check_uniqueness_factory, polyfill_autoincrement, polyfill_refresh_after_dml, refresh_table
34 changes: 34 additions & 0 deletions cratedb_toolkit/sqlalchemy/patch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
import typing as t

import sqlalchemy as sa


def patch_inspector():
"""
When using `get_table_names()`, make sure the correct schema name gets used.
Apparently, SQLAlchemy does not honor the `search_path` of the engine, when
using the inspector?
FIXME: Bug in CrateDB SQLAlchemy dialect?
"""

def get_effective_schema(engine: sa.Engine):
schema_name_raw = engine.url.query.get("schema")
schema_name = None
if isinstance(schema_name_raw, str):
schema_name = schema_name_raw
elif isinstance(schema_name_raw, tuple):
schema_name = schema_name_raw[0]
return schema_name

from crate.client.sqlalchemy.dialect import CrateDialect

get_table_names_dist = CrateDialect.get_table_names

def get_table_names(self, connection: sa.Connection, schema: t.Optional[str] = None, **kw: t.Any) -> t.List[str]:
if schema is None:
schema = get_effective_schema(connection.engine)
return get_table_names_dist(self, connection=connection, schema=schema, **kw)

CrateDialect.get_table_names = get_table_names # type: ignore
108 changes: 108 additions & 0 deletions cratedb_toolkit/sqlalchemy/polyfill.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
import itertools

import sqlalchemy as sa
from sqlalchemy.event import listen


def polyfill_autoincrement():
"""
Configure SQLAlchemy model columns with an alternative to `autoincrement=True`.
In this case, use a random identifier: Nagamani19, a short, unique,
non-sequential identifier based on Hashids.
TODO: Submit patch to `crate-python`, to be enabled by a
dialect parameter `crate_polyfill_autoincrement` or such.
"""
import sqlalchemy.sql.schema as schema
from sqlalchemy import func

init_dist = schema.Column.__init__

def __init__(self, *args, **kwargs):
if "autoincrement" in kwargs:
del kwargs["autoincrement"]
if "default" not in kwargs:
kwargs["default"] = func.now()
init_dist(self, *args, **kwargs)

schema.Column.__init__ = __init__ # type: ignore[method-assign]


def check_uniqueness_factory(sa_entity, attribute_name):
"""
Run a manual column value uniqueness check on a table, and raise an IntegrityError if applicable.
CrateDB does not support the UNIQUE constraint on columns. This attempts to emulate it.
TODO: Submit patch to `crate-python`, to be enabled by a
dialect parameter `crate_polyfill_unique` or such.
"""

def check_uniqueness(mapper, connection, target):
from sqlalchemy.exc import IntegrityError

if isinstance(target, sa_entity):
# TODO: How to use `session.query(SqlExperiment)` here?
stmt = (
mapper.selectable.select()
.filter(getattr(sa_entity, attribute_name) == getattr(target, attribute_name))
.compile(bind=connection.engine)
)
results = connection.execute(stmt)
if results.rowcount > 0:
raise IntegrityError(
statement=stmt,
params=[],
orig=Exception(f"DuplicateKeyException on column: {target.__tablename__}.{attribute_name}"),
)

return check_uniqueness


def polyfill_refresh_after_dml(session):
"""
Run `REFRESH TABLE <tablename>` after each INSERT, UPDATE, and DELETE operation.
CrateDB is eventually consistent, i.e. write operations are not flushed to
disk immediately, so readers may see stale data. In a traditional OLTP-like
application, this is not applicable.
This SQLAlchemy extension makes sure that data is synchronized after each
operation manipulating data.
> `after_{insert,update,delete}` events only apply to the session flush operation
> and do not apply to the ORM DML operations described at ORM-Enabled INSERT,
> UPDATE, and DELETE statements. To intercept ORM DML events, use
> `SessionEvents.do_orm_execute().`
> -- https://docs.sqlalchemy.org/en/20/orm/events.html#sqlalchemy.orm.MapperEvents.after_insert
> Intercept statement executions that occur on behalf of an ORM Session object.
> -- https://docs.sqlalchemy.org/en/20/orm/events.html#sqlalchemy.orm.SessionEvents.do_orm_execute
> Execute after flush has completed, but before commit has been called.
> -- https://docs.sqlalchemy.org/en/20/orm/events.html#sqlalchemy.orm.SessionEvents.after_flush
TODO: Submit patch to `crate-python`, to be enabled by a
dialect parameter `crate_dml_refresh` or such.
""" # noqa: E501
listen(session, "after_flush", do_flush)


def do_flush(session, flush_context):
"""
SQLAlchemy event handler for the 'after_flush' event,
invoking `REFRESH TABLE` on each table which has been modified.
"""
dirty_entities = itertools.chain(session.new, session.dirty, session.deleted)
dirty_classes = {entity.__class__ for entity in dirty_entities}
for class_ in dirty_classes:
refresh_table(session, class_)


def refresh_table(connection, target):
"""
Invoke a `REFRESH TABLE` statement.
"""
sql = f"REFRESH TABLE {target.__tablename__}"
connection.execute(sa.text(sql))
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ dependencies = [
"colorama<1",
"colorlog",
"crash",
"crate[sqlalchemy]",
"crate[sqlalchemy]>=0.34",
"sqlalchemy>=2",
]
[project.optional-dependencies]
Expand Down Expand Up @@ -216,6 +216,7 @@ extend-exclude = [
"tests/*" = ["S101"] # Allow use of `assert`, and `print`.
"examples/*" = ["T201"] # Allow `print`
"cratedb_toolkit/retention/cli.py" = ["T201"] # Allow `print`
"cratedb_toolkit/sqlalchemy/__init__.py" = ["F401"] # Allow `module´ imported but unused

[tool.setuptools.packages.find]
namespaces = false
Expand Down
Empty file added tests/sqlalchemy/__init__.py
Empty file.
13 changes: 13 additions & 0 deletions tests/sqlalchemy/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
import pytest

from cratedb_toolkit.util import DatabaseAdapter
from tests.conftest import TESTDRIVE_DATA_SCHEMA


@pytest.fixture
def database(cratedb):
"""
Provide a client database adapter, which is connected to the test database instance.
"""
database_url = cratedb.get_connection_url() + "?schema=" + TESTDRIVE_DATA_SCHEMA
yield DatabaseAdapter(dburi=database_url)
42 changes: 42 additions & 0 deletions tests/sqlalchemy/test_patch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import sqlalchemy as sa

from cratedb_toolkit.sqlalchemy import patch_inspector
from tests.conftest import TESTDRIVE_DATA_SCHEMA


def test_inspector_vanilla(database):
"""
Vanilla SQLAlchemy Inspector tests.
"""
tablename = f'"{TESTDRIVE_DATA_SCHEMA}"."foobar"'
inspector: sa.Inspector = sa.inspect(database.engine)
database.run_sql(f"CREATE TABLE {tablename} AS SELECT 1")

assert inspector.has_schema(TESTDRIVE_DATA_SCHEMA) is True

table_names = inspector.get_table_names(schema=TESTDRIVE_DATA_SCHEMA)
assert table_names == ["foobar"]

view_names = inspector.get_view_names(schema=TESTDRIVE_DATA_SCHEMA)
assert view_names == []

indexes = inspector.get_indexes(tablename)
assert indexes == []


def test_inspector_patched(database):
"""
Patched SQLAlchemy Inspector tests.
Both MLflow and LangChain invoke `get_table_names()` without a `schema` argument.
This verifies that it still works, when it properly has been assigned to
the `?schema=` connection string URL parameter.
"""
patch_inspector()
tablename = f'"{TESTDRIVE_DATA_SCHEMA}"."foobar"'
inspector: sa.Inspector = sa.inspect(database.engine)
database.run_sql(f"CREATE TABLE {tablename} AS SELECT 1")
assert inspector.has_schema(TESTDRIVE_DATA_SCHEMA) is True

table_names = inspector.get_table_names()
assert table_names == ["foobar"]
109 changes: 109 additions & 0 deletions tests/sqlalchemy/test_polyfill.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
import re

import pytest
import sqlalchemy as sa

from cratedb_toolkit.sqlalchemy import check_uniqueness_factory, polyfill_autoincrement, polyfill_refresh_after_dml


def get_autoincrement_model():
"""
Provide a minimal SQLAlchemy model including an AUTOINCREMENT primary key.
"""
Base = sa.orm.declarative_base()

class FooBar(Base):
"""
Minimal SQLAlchemy model with autoincrement primary key.
"""

__tablename__ = "foobar"
identifier = sa.Column(sa.BigInteger, primary_key=True, autoincrement=True)
foo = sa.Column(sa.String)

return FooBar


def get_unique_model():
"""
Provide a minimal SQLAlchemy model including a column with UNIQUE constraint.
"""
Base = sa.orm.declarative_base()

class FooBar(Base):
"""
Minimal SQLAlchemy model with UNIQUE constraint.
"""

__tablename__ = "foobar"
identifier = sa.Column(sa.BigInteger, primary_key=True, default=sa.func.now())
name = sa.Column(sa.String, unique=True, nullable=False)

return FooBar


def test_autoincrement_vanilla(database):
"""
When using a model including an autoincrement column, and not assigning a value, CrateDB will fail.
"""
FooBar = get_autoincrement_model()
FooBar.metadata.create_all(database.engine)
with sa.orm.Session(database.engine) as session:
session.add(FooBar(foo="bar"))
with pytest.raises(sa.exc.ProgrammingError) as ex:
session.commit()
assert ex.match(
re.escape("SQLParseException[Column `identifier` is required but is missing from the insert statement]")
)


def test_autoincrement_polyfill(database):
"""
When using a model including an autoincrement column, and the corresponding polyfill
is installed, the procedure will succeed.
"""
polyfill_autoincrement()

FooBar = get_autoincrement_model()
FooBar.metadata.create_all(database.engine)
with sa.orm.Session(database.engine) as session:
session.add(FooBar(foo="bar"))
session.commit()


def test_unique_patched(database):
"""
When using a model including a column with UNIQUE constraint, the SQLAlchemy dialect will ignore it.
"""
FooBar = get_unique_model()
FooBar.metadata.create_all(database.engine)

with sa.orm.Session(database.engine) as session:
session.add(FooBar(name="name-1"))
session.commit()
session.add(FooBar(name="name-1"))
session.commit()


def test_unique_patched_and_active(database):
"""
When using a model including a column with UNIQUE constraint, enabling the patch,
and activating the uniqueness check, SQLAlchemy will raise `DuplicateKeyException`
errors if uniqueness constraints don't hold.
"""
FooBar = get_unique_model()
FooBar.metadata.create_all(database.engine)

# For uniqueness checks to take place, installing an event handler is needed.
# TODO: Maybe add to some helper function?
# TODO: Maybe derive from the model definition itself?
sa.event.listen(FooBar, "before_insert", check_uniqueness_factory(FooBar, "name"))

with sa.orm.Session(database.engine) as session:
polyfill_refresh_after_dml(session)
session.add(FooBar(name="name-1"))
session.commit()
session.add(FooBar(name="name-1"))
with pytest.raises(sa.exc.IntegrityError) as ex:
session.commit()
assert ex.match("DuplicateKeyException on column: foobar.name")

0 comments on commit 3cae25f

Please sign in to comment.