Skip to content

Commit

Permalink
get_table_names(): Fix reflection method to respect schema query ar…
Browse files Browse the repository at this point in the history
…gument

It did not respect the `schema` query argument in SQLAlchemy connection
URLs.

Co-authored-by: Marios Trivyzas <[email protected]>
  • Loading branch information
amotl and matriv committed Jun 10, 2024
1 parent 20443d2 commit edc6f4d
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 0 deletions.
2 changes: 2 additions & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
- Added support for CrateDB's [FLOAT_VECTOR] data type and its accompanying
[KNN_MATCH] function, for HNSW matches. For SQLAlchemy column definitions,
you can use it like `FloatVector(dimensions=1536)`.
- Fixed `get_table_names()` reflection method to respect the
`schema` query argument in SQLAlchemy connection URLs.

[FLOAT_VECTOR]: https://cratedb.com/docs/crate/reference/en/latest/general/ddl/data-types.html#float-vector
[KNN_MATCH]: https://cratedb.com/docs/crate/reference/en/latest/general/builtins/scalar-functions.html#scalar-knn-match
Expand Down
11 changes: 11 additions & 0 deletions src/sqlalchemy_cratedb/dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,15 @@ def connect(self, host=None, port=None, *args, **kwargs):
def _get_default_schema_name(self, connection):
return 'doc'

def _get_effective_schema_name(self, connection):
schema_name_raw = connection.engine.url.query.get("schema")
schema_name = None
if isinstance(schema_name_raw, str):
schema_name = schema_name_raw
elif isinstance(schema_name_raw, tuple):
schema_name = schema_name_raw[0]
return schema_name

def _get_server_version_info(self, connection):
return tuple(connection.connection.lowest_server_version.version)

Expand Down Expand Up @@ -258,6 +267,8 @@ def get_schema_names(self, connection, **kw):

@reflection.cache
def get_table_names(self, connection, schema=None, **kw):
if schema is None:
schema = self._get_effective_schema_name(connection)
cursor = connection.exec_driver_sql(
"SELECT table_name FROM information_schema.tables "
"WHERE {0} = ? "
Expand Down

0 comments on commit edc6f4d

Please sign in to comment.