-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Contrib: Add a few SQLAlchemy patches and polyfills
They do not fit well into the vanilla Python driver / SQLAlchemy dialect.
- Loading branch information
Showing
9 changed files
with
314 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
from .patch import patch_inspector | ||
from .polyfill import check_uniqueness_factory, polyfill_autoincrement, polyfill_refresh_after_dml, refresh_table |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
import typing as t | ||
|
||
import sqlalchemy as sa | ||
|
||
|
||
def patch_inspector():
    """
    Patch `CrateDialect.get_table_names` to honor the `?schema=` URL parameter.

    When `get_table_names()` is invoked without a `schema` argument, the patched
    method falls back to the schema name given by the `schema` query parameter
    of the engine URL. An explicitly passed `schema` always takes precedence.

    Apparently, SQLAlchemy does not honor the `search_path` of the engine, when
    using the inspector?

    The patch is installed at most once; calling this function again is a no-op,
    so wrappers do not pile up on top of each other.

    FIXME: Bug in CrateDB SQLAlchemy dialect?
    """

    from crate.client.sqlalchemy.dialect import CrateDialect

    # Already patched by an earlier call: do not wrap the wrapper again.
    if getattr(CrateDialect.get_table_names, "_cratedb_toolkit_schema_patch", False):
        return

    def get_effective_schema(engine: sa.Engine):
        """
        Return the schema name from the engine URL's `schema` query parameter, or `None`.
        """
        schema_name_raw = engine.url.query.get("schema")
        if isinstance(schema_name_raw, str):
            return schema_name_raw
        # A query parameter given multiple times is represented as a tuple; use the first value.
        if isinstance(schema_name_raw, tuple) and schema_name_raw:
            return schema_name_raw[0]
        return None

    # Keep a reference to the original implementation, to delegate to it.
    get_table_names_dist = CrateDialect.get_table_names

    def get_table_names(self, connection: sa.Connection, schema: t.Optional[str] = None, **kw: t.Any) -> t.List[str]:
        if schema is None:
            schema = get_effective_schema(connection.engine)
        return get_table_names_dist(self, connection=connection, schema=schema, **kw)

    get_table_names._cratedb_toolkit_schema_patch = True  # type: ignore[attr-defined]
    CrateDialect.get_table_names = get_table_names  # type: ignore
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,108 @@ | ||
import itertools | ||
|
||
import sqlalchemy as sa | ||
from sqlalchemy.event import listen | ||
|
||
|
||
def polyfill_autoincrement():
    """
    Configure SQLAlchemy model columns with an alternative to `autoincrement=True`.

    This monkeypatches `sqlalchemy.sql.schema.Column.__init__`: the `autoincrement`
    keyword argument is removed, and, when it was enabled and the column does not
    define a `default` on its own, `func.now()` is installed as the column default,
    i.e. the identifier is derived from the current timestamp.

    Columns declared with `autoincrement=False` do not receive the timestamp default.
    The patch is installed at most once; calling this function again is a no-op.

    TODO: Submit patch to `crate-python`, to be enabled by a
          dialect parameter `crate_polyfill_autoincrement` or such.
    """
    import sqlalchemy.sql.schema as schema
    from sqlalchemy import func

    # Already patched by an earlier call: do not wrap the wrapper again.
    if getattr(schema.Column.__init__, "_crate_polyfill_autoincrement", False):
        return

    # Keep a reference to the original constructor, to delegate to it.
    init_dist = schema.Column.__init__

    def __init__(self, *args, **kwargs):
        # Always strip the `autoincrement` argument, but only emulate it
        # when it was actually requested (`True`, `"auto"`, ...).
        autoincrement = kwargs.pop("autoincrement", False)
        if autoincrement and "default" not in kwargs:
            kwargs["default"] = func.now()
        init_dist(self, *args, **kwargs)

    __init__._crate_polyfill_autoincrement = True  # type: ignore[attr-defined]
    schema.Column.__init__ = __init__  # type: ignore[method-assign]
|
||
|
||
def check_uniqueness_factory(sa_entity, attribute_name):
    """
    Create an event handler running a manual column value uniqueness check on a table.

    CrateDB does not support the UNIQUE constraint on columns. This attempts to emulate it.

    :param sa_entity: The mapped class whose instances should be checked.
    :param attribute_name: Name of the attribute which must hold unique values.
    :return: A function with the `(mapper, connection, target)` signature, suitable
             for registering as a mapper event listener, e.g. on `before_insert`.
             It raises `sqlalchemy.exc.IntegrityError` when a record with the same
             attribute value already exists.

    TODO: Submit patch to `crate-python`, to be enabled by a
          dialect parameter `crate_polyfill_unique` or such.
    """

    def check_uniqueness(mapper, connection, target):
        from sqlalchemy.exc import IntegrityError

        # Only check instances of the configured entity.
        if not isinstance(target, sa_entity):
            return

        # TODO: How to use `session.query(SqlExperiment)` here?
        stmt = (
            mapper.selectable.select()
            .filter(getattr(sa_entity, attribute_name) == getattr(target, attribute_name))
            .compile(bind=connection.engine)
        )
        results = connection.execute(stmt)

        # Probe for an existing row by fetching it, instead of relying on
        # `rowcount`, which is not meant to report the size of SELECT results.
        if results.first() is not None:
            raise IntegrityError(
                statement=str(stmt),
                params=[],
                orig=Exception(f"DuplicateKeyException on column: {target.__tablename__}.{attribute_name}"),
            )

    return check_uniqueness
|
||
|
||
def polyfill_refresh_after_dml(session):
    """
    Run `REFRESH TABLE <tablename>` after each INSERT, UPDATE, and DELETE operation.

    CrateDB is eventually consistent, i.e. write operations are not flushed to
    disk immediately, so readers may see stale data. In a traditional OLTP-like
    application, this is not acceptable.

    This SQLAlchemy extension makes sure that data is synchronized after each
    operation manipulating data. It does so by registering `do_flush` as an
    `after_flush` event handler on the given `session`. Registering the same
    target more than once has no additional effect.

    > `after_{insert,update,delete}` events only apply to the session flush operation
    > and do not apply to the ORM DML operations described at ORM-Enabled INSERT,
    > UPDATE, and DELETE statements. To intercept ORM DML events, use
    > `SessionEvents.do_orm_execute().`
    > -- https://docs.sqlalchemy.org/en/20/orm/events.html#sqlalchemy.orm.MapperEvents.after_insert

    > Intercept statement executions that occur on behalf of an ORM Session object.
    > -- https://docs.sqlalchemy.org/en/20/orm/events.html#sqlalchemy.orm.SessionEvents.do_orm_execute

    > Execute after flush has completed, but before commit has been called.
    > -- https://docs.sqlalchemy.org/en/20/orm/events.html#sqlalchemy.orm.SessionEvents.after_flush

    TODO: Submit patch to `crate-python`, to be enabled by a
          dialect parameter `crate_dml_refresh` or such.
    """  # noqa: E501
    # Registering the handler twice would issue each `REFRESH TABLE` twice per flush.
    if not sa.event.contains(session, "after_flush", do_flush):
        listen(session, "after_flush", do_flush)
|
||
|
||
def do_flush(session, flush_context):
    """
    Handle the SQLAlchemy 'after_flush' event.

    Collect the classes of all entities the session tracks as new, dirty,
    or deleted, and invoke `REFRESH TABLE` once per class.
    """
    touched_classes = set()
    for entity in itertools.chain(session.new, session.dirty, session.deleted):
        touched_classes.add(entity.__class__)

    for model_class in touched_classes:
        refresh_table(session, model_class)
|
||
|
||
def refresh_table(connection, target):
    """
    Invoke a `REFRESH TABLE` statement.

    :param connection: Object providing an `execute()` method accepting a SQLAlchemy
                       text clause, e.g. a connection or a session.
    :param target: Either an object carrying a `__tablename__` attribute, like a
                   declarative model class, or the table name as a plain string.
                   The name is interpolated into the statement verbatim, so any
                   quoting needs to be applied by the caller.
    """
    if isinstance(target, str):
        table_name = target
    else:
        table_name = target.__tablename__
    sql = f"REFRESH TABLE {table_name}"
    connection.execute(sa.text(sql))
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
import pytest | ||
|
||
from cratedb_toolkit.util import DatabaseAdapter | ||
from tests.conftest import TESTDRIVE_DATA_SCHEMA | ||
|
||
|
||
@pytest.fixture
def database(cratedb):
    """
    Yield a database adapter connected to the test database instance,
    with the connection URL pointing at the test data schema.
    """
    base_url = cratedb.get_connection_url()
    schema_suffix = "?schema=" + TESTDRIVE_DATA_SCHEMA
    adapter = DatabaseAdapter(dburi=base_url + schema_suffix)
    yield adapter
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
import sqlalchemy as sa | ||
|
||
from cratedb_toolkit.sqlalchemy import patch_inspector | ||
from tests.conftest import TESTDRIVE_DATA_SCHEMA | ||
|
||
|
||
def test_inspector_vanilla(database):
    """
    Exercise the unpatched SQLAlchemy Inspector, passing the schema name explicitly.
    """
    qualified_name = f'"{TESTDRIVE_DATA_SCHEMA}"."foobar"'
    inspector: sa.Inspector = sa.inspect(database.engine)
    database.run_sql(f"CREATE TABLE {qualified_name} AS SELECT 1")

    # The schema exists, and contains exactly the one table, no views, and no indexes.
    assert inspector.has_schema(TESTDRIVE_DATA_SCHEMA) is True
    assert inspector.get_table_names(schema=TESTDRIVE_DATA_SCHEMA) == ["foobar"]
    assert inspector.get_view_names(schema=TESTDRIVE_DATA_SCHEMA) == []
    assert inspector.get_indexes(qualified_name) == []
|
||
|
||
def test_inspector_patched(database):
    """
    Exercise the patched SQLAlchemy Inspector.

    Both MLflow and LangChain invoke `get_table_names()` without a `schema` argument.
    This verifies the call still works when the schema has been assigned through
    the `?schema=` connection string URL parameter.
    """
    patch_inspector()

    qualified_name = f'"{TESTDRIVE_DATA_SCHEMA}"."foobar"'
    inspector: sa.Inspector = sa.inspect(database.engine)
    database.run_sql(f"CREATE TABLE {qualified_name} AS SELECT 1")

    assert inspector.has_schema(TESTDRIVE_DATA_SCHEMA) is True
    # No `schema` argument here: the patch must resolve it from the engine URL.
    assert inspector.get_table_names() == ["foobar"]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,109 @@ | ||
import re | ||
|
||
import pytest | ||
import sqlalchemy as sa | ||
|
||
from cratedb_toolkit.sqlalchemy import check_uniqueness_factory, polyfill_autoincrement, polyfill_refresh_after_dml | ||
|
||
|
||
def get_autoincrement_model():
    """
    Build and return a minimal SQLAlchemy model class whose primary key
    is declared with `autoincrement=True`.
    """
    model_base = sa.orm.declarative_base()

    class FooBar(model_base):
        """
        Minimal SQLAlchemy model with autoincrement primary key.
        """

        __tablename__ = "foobar"

        # Primary key relying on AUTOINCREMENT.
        identifier = sa.Column(sa.BigInteger, primary_key=True, autoincrement=True)
        foo = sa.Column(sa.String)

    return FooBar
|
||
|
||
def get_unique_model():
    """
    Build and return a minimal SQLAlchemy model class which declares
    a UNIQUE constraint on one of its columns.
    """
    model_base = sa.orm.declarative_base()

    class FooBar(model_base):
        """
        Minimal SQLAlchemy model with UNIQUE constraint.
        """

        __tablename__ = "foobar"

        # Primary key populated from the current timestamp.
        identifier = sa.Column(sa.BigInteger, primary_key=True, default=sa.func.now())
        # Column carrying the UNIQUE constraint.
        name = sa.Column(sa.String, unique=True, nullable=False)

    return FooBar
|
||
|
||
def test_autoincrement_vanilla(database):
    """
    Without the polyfill, inserting a record into a model with an autoincrement
    column, while not assigning a value to it, makes CrateDB fail.
    """
    model = get_autoincrement_model()
    model.metadata.create_all(database.engine)

    expected_error = re.escape(
        "SQLParseException[Column `identifier` is required but is missing from the insert statement]"
    )

    with sa.orm.Session(database.engine) as session:
        session.add(model(foo="bar"))
        with pytest.raises(sa.exc.ProgrammingError) as ex:
            session.commit()
        assert ex.match(expected_error)
|
||
|
||
def test_autoincrement_polyfill(database):
    """
    With the autoincrement polyfill installed, inserting a record into a model
    with an autoincrement column succeeds.
    """
    # The polyfill must be installed before the model class gets defined.
    polyfill_autoincrement()

    model = get_autoincrement_model()
    model.metadata.create_all(database.engine)

    with sa.orm.Session(database.engine) as session:
        record = model(foo="bar")
        session.add(record)
        session.commit()
|
||
|
||
def test_unique_patched(database):
    """
    The SQLAlchemy dialect ignores the UNIQUE constraint of a model column,
    so inserting the same value twice does not raise an error.
    """
    model = get_unique_model()
    model.metadata.create_all(database.engine)

    with sa.orm.Session(database.engine) as session:
        # Insert two records sharing the same `name`, committing after each one.
        for _ in range(2):
            session.add(model(name="name-1"))
            session.commit()
|
||
|
||
def test_unique_patched_and_active(database):
    """
    With the uniqueness check activated for a model column carrying a UNIQUE
    constraint, SQLAlchemy raises `DuplicateKeyException` errors as soon as
    the uniqueness constraint does not hold.
    """
    model = get_unique_model()
    model.metadata.create_all(database.engine)

    # For uniqueness checks to take place, installing an event handler is needed.
    # TODO: Maybe add to some helper function?
    # TODO: Maybe derive from the model definition itself?
    uniqueness_guard = check_uniqueness_factory(model, "name")
    sa.event.listen(model, "before_insert", uniqueness_guard)

    with sa.orm.Session(database.engine) as session:
        polyfill_refresh_after_dml(session)

        # The first record goes in without problems.
        session.add(model(name="name-1"))
        session.commit()

        # The second record with the same name must be rejected.
        session.add(model(name="name-1"))
        with pytest.raises(sa.exc.IntegrityError) as ex:
            session.commit()
        assert ex.match("DuplicateKeyException on column: foobar.name")