
Commit

Fix sqlite
stuartmcalpine committed Dec 13, 2024
1 parent ab36caf commit 000c36d
Showing 3 changed files with 13 additions and 10 deletions.
2 changes: 1 addition & 1 deletion src/dataregistry/db_basic.py
@@ -221,7 +221,7 @@ def _reflect(self):
         self.metadata["schema_version"] = f"{r[0]}.{r[1]}.{r[2]}"

         # Add production schema tables to metadata
-        if self._prod_schema is not None:
+        if self._prod_schema is not None and self.dialect != "sqlite":
             metadata.reflect(self.engine, self._prod_schema)
             cols.remove("associated_production")
             prov_name = ".".join([self._prod_schema, "provenance"])
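Why the new dialect guard: SQLite has no counterpart to PostgreSQL's named schemas, so reflecting a separate production schema only makes sense for non-sqlite backends. A minimal sketch of the idea, assuming an illustrative "production" schema name and a throwaway engine URL rather than the registry's own connection handling:

# Sketch only: schema-qualified reflection is skipped for SQLite.
from sqlalchemy import MetaData, create_engine

engine = create_engine("sqlite:///example.db")  # illustrative URL
metadata = MetaData()

if engine.dialect.name != "sqlite":
    # PostgreSQL keeps production tables under a named schema,
    # e.g. "production.provenance", so that schema is reflected explicitly.
    metadata.reflect(engine, schema="production")  # schema name is illustrative
else:
    # SQLite has a single flat namespace; there is no separate schema to reflect.
    metadata.reflect(engine)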
18 changes: 10 additions & 8 deletions src/dataregistry/query.py
@@ -176,7 +176,8 @@ def _parse_selected_columns(self, column_names):
         if column_names is None:
             column_names = []
             for table in self.db_connection.metadata["tables"]:
-                if table.split(".")[1] == "dataset":
+                tname = table if self.db_connection.dialect == "sqlite" else table.split(".")[1]
+                if tname == "dataset":
                     column_names.extend(
                         [
                             x.table.name + "." + x.name
@@ -419,6 +420,7 @@ def find_datasets(

         # Construct query
         for schema in column_list.keys():  # Loop over each schema
+            schema_str = "" if self.db_connection.dialect == "sqlite" else f"{schema}."
             columns = [f"{p.table.name}.{p.name}" for p in column_list[schema]]

             stmt = select(
@@ -427,29 +429,29 @@

             # Create joins
             if len(tables_required) > 1:
-                j = self.db_connection.metadata["tables"][f"{schema}.dataset"]
+                j = self.db_connection.metadata["tables"][f"{schema_str}dataset"]
                 for i in range(len(tables_required)):
                     if tables_required[i] in ["dataset", "keyword", "dependency"]:
                         continue

                     j = j.join(
                         self.db_connection.metadata["tables"][
-                            f"{schema}.{tables_required[i]}"
+                            f"{schema_str}{tables_required[i]}"
                         ]
                     )

                 # Special case for many-to-many keyword join
                 if "keyword" in tables_required:
                     j = j.join(
                         self.db_connection.metadata["tables"][
-                            f"{schema}.dataset_keyword"
+                            f"{schema_str}dataset_keyword"
                         ]
-                    ).join(self.db_connection.metadata["tables"][f"{schema}.keyword"])
+                    ).join(self.db_connection.metadata["tables"][f"{schema_str}keyword"])

                 # Special case for dependencies
                 if "dependency" in tables_required:
-                    dataset_table = self.db_connection.metadata["tables"][f"{schema}.dataset"]
-                    dependency_table = self.db_connection.metadata["tables"][f"{schema}.dependency"]
+                    dataset_table = self.db_connection.metadata["tables"][f"{schema_str}dataset"]
+                    dependency_table = self.db_connection.metadata["tables"][f"{schema_str}dependency"]

                     j = j.join(
                         dependency_table,
@@ -460,7 +462,7 @@
             else:
                 stmt = stmt.select_from(
                     self.db_connection.metadata["tables"][
-                        f"{schema}.{tables_required[0]}"
+                        f"{schema_str}{tables_required[0]}"
                     ]
                 )

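The schema_str prefix matters because SQLAlchemy keys reflected tables by "schema.table" when a schema is in use, but by the bare table name under SQLite. A small self-contained sketch of that keying convention (the dialect and schema values here are illustrative, not read from the registry's configuration):

# Sketch only: how the metadata["tables"] lookup keys differ by dialect.
def table_key(dialect: str, schema: str, table: str) -> str:
    """Build the key used to find a reflected table in the metadata."""
    schema_str = "" if dialect == "sqlite" else f"{schema}."
    return f"{schema_str}{table}"

print(table_key("postgresql", "myschema", "dataset"))  # -> "myschema.dataset"
print(table_key("sqlite", "myschema", "dataset"))      # -> "dataset"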
3 changes: 2 additions & 1 deletion src/dataregistry/registrar/dataset.py
@@ -1016,9 +1016,10 @@ def add_keywords(self, dataset_id, keywords):
             )

             result = conn.execute(stmt)
+            rows = result.fetchall()

             # If we don't have the keyword, add it
-            if result.rowcount == 0:
+            if len(rows) == 0:
                 add_table_row(
                     conn,
                     dataset_keyword_table,
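Context for replacing result.rowcount with len(rows): DBAPI drivers are not required to report a row count for SELECT statements, and Python's sqlite3 driver reports -1, so an existence check based on rowcount does not behave as intended under SQLite. A standalone sketch of the behaviour (the in-memory database and table are illustrative):

# Sketch only: rowcount is unreliable for SELECTs under sqlite3.
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE keyword (keyword TEXT)")
conn.execute("INSERT INTO keyword VALUES ('supernova')")

cur = conn.execute("SELECT * FROM keyword WHERE keyword = 'supernova'")
print(cur.rowcount)  # -1: the driver does not track row counts for SELECTs
rows = cur.fetchall()
print(len(rows))     # 1: counting the fetched rows works on every backend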
