diff --git a/CHANGES.md b/CHANGES.md index 8ed58e9..576bba8 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -7,6 +7,8 @@ - Added support for CrateDB's [FLOAT_VECTOR] data type and its accompanying [KNN_MATCH] function, for HNSW matches. For SQLAlchemy column definitions, you can use it like `FloatVector(dimensions=1536)`. +- Fixed `get_table_names()` reflection method to respect the + `schema` query argument in SQLAlchemy connection URLs. [FLOAT_VECTOR]: https://cratedb.com/docs/crate/reference/en/latest/general/ddl/data-types.html#float-vector [KNN_MATCH]: https://cratedb.com/docs/crate/reference/en/latest/general/builtins/scalar-functions.html#scalar-knn-match diff --git a/src/sqlalchemy_cratedb/dialect.py b/src/sqlalchemy_cratedb/dialect.py index 73f6c53..425da36 100644 --- a/src/sqlalchemy_cratedb/dialect.py +++ b/src/sqlalchemy_cratedb/dialect.py @@ -229,6 +229,15 @@ def connect(self, host=None, port=None, *args, **kwargs): def _get_default_schema_name(self, connection): return 'doc' + def _get_effective_schema_name(self, connection): + schema_name_raw = connection.engine.url.query.get("schema") + schema_name = None + if isinstance(schema_name_raw, str): + schema_name = schema_name_raw + elif isinstance(schema_name_raw, tuple): + schema_name = schema_name_raw[0] + return schema_name + def _get_server_version_info(self, connection): return tuple(connection.connection.lowest_server_version.version) @@ -258,6 +267,8 @@ def get_schema_names(self, connection, **kw): @reflection.cache def get_table_names(self, connection, schema=None, **kw): + if schema is None: + schema = self._get_effective_schema_name(connection) cursor = connection.exec_driver_sql( "SELECT table_name FROM information_schema.tables " "WHERE {0} = ? "