From 000c36da7c6d2d38f9f95734a60565ad7bbf87bc Mon Sep 17 00:00:00 2001
From: Stuart McAlpine
Date: Fri, 13 Dec 2024 16:44:19 +0100
Subject: [PATCH] Fix sqlite

---
 src/dataregistry/db_basic.py          |  2 +-
 src/dataregistry/query.py             | 18 ++++++++++--------
 src/dataregistry/registrar/dataset.py |  3 ++-
 3 files changed, 13 insertions(+), 10 deletions(-)

diff --git a/src/dataregistry/db_basic.py b/src/dataregistry/db_basic.py
index a77751d..329571f 100644
--- a/src/dataregistry/db_basic.py
+++ b/src/dataregistry/db_basic.py
@@ -221,7 +221,7 @@ def _reflect(self):
             self.metadata["schema_version"] = f"{r[0]}.{r[1]}.{r[2]}"
 
         # Add production schema tables to metadata
-        if self._prod_schema is not None:
+        if self._prod_schema is not None and self.dialect != "sqlite":
             metadata.reflect(self.engine, self._prod_schema)
             cols.remove("associated_production")
             prov_name = ".".join([self._prod_schema, "provenance"])
diff --git a/src/dataregistry/query.py b/src/dataregistry/query.py
index 916c622..a0aa6fc 100644
--- a/src/dataregistry/query.py
+++ b/src/dataregistry/query.py
@@ -176,7 +176,8 @@ def _parse_selected_columns(self, column_names):
         if column_names is None:
             column_names = []
             for table in self.db_connection.metadata["tables"]:
-                if table.split(".")[1] == "dataset":
+                tname = table if self.db_connection.dialect == "sqlite" else table.split(".")[1]
+                if tname == "dataset":
                     column_names.extend(
                         [
                             x.table.name + "." + x.name
@@ -419,6 +420,7 @@ def find_datasets(
 
         # Construct query
         for schema in column_list.keys():  # Loop over each schema
+            schema_str = "" if self.db_connection.dialect == "sqlite" else f"{schema}."
            columns = [f"{p.table.name}.{p.name}" for p in column_list[schema]]
 
             stmt = select(
@@ -427,14 +429,14 @@ def find_datasets(
 
             # Create joins
             if len(tables_required) > 1:
-                j = self.db_connection.metadata["tables"][f"{schema}.dataset"]
+                j = self.db_connection.metadata["tables"][f"{schema_str}dataset"]
 
                 for i in range(len(tables_required)):
                     if tables_required[i] in ["dataset", "keyword", "dependency"]:
                         continue
 
                     j = j.join(
                         self.db_connection.metadata["tables"][
-                            f"{schema}.{tables_required[i]}"
+                            f"{schema_str}{tables_required[i]}"
                         ]
                     )
@@ -442,14 +444,14 @@ def find_datasets(
             if "keyword" in tables_required:
                 j = j.join(
                     self.db_connection.metadata["tables"][
-                        f"{schema}.dataset_keyword"
+                        f"{schema_str}dataset_keyword"
                     ]
-                ).join(self.db_connection.metadata["tables"][f"{schema}.keyword"])
+                ).join(self.db_connection.metadata["tables"][f"{schema_str}keyword"])
 
             # Special case for dependencies
             if "dependency" in tables_required:
-                dataset_table = self.db_connection.metadata["tables"][f"{schema}.dataset"]
-                dependency_table = self.db_connection.metadata["tables"][f"{schema}.dependency"]
+                dataset_table = self.db_connection.metadata["tables"][f"{schema_str}dataset"]
+                dependency_table = self.db_connection.metadata["tables"][f"{schema_str}dependency"]
 
                 j = j.join(
                     dependency_table,
@@ -460,7 +462,7 @@ def find_datasets(
             else:
                 stmt = stmt.select_from(
                     self.db_connection.metadata["tables"][
-                        f"{schema}.{tables_required[0]}"
+                        f"{schema_str}{tables_required[0]}"
                     ]
                 )
 
diff --git a/src/dataregistry/registrar/dataset.py b/src/dataregistry/registrar/dataset.py
index 7172979..3f7b870 100644
--- a/src/dataregistry/registrar/dataset.py
+++ b/src/dataregistry/registrar/dataset.py
@@ -1016,9 +1016,10 @@ def add_keywords(self, dataset_id, keywords):
             )
 
             result = conn.execute(stmt)
+            rows = result.fetchall()
 
             # If we don't have the keyword, add it
-            if result.rowcount == 0:
+            if len(rows) == 0:
                 add_table_row(
                     conn,
                     dataset_keyword_table,
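
Background on the schema_str changes above: SQLite has no named schemas, so
when SQLAlchemy reflects a SQLite database the keys in metadata.tables are
bare table names ("dataset"), whereas reflecting a PostgreSQL schema produces
qualified keys ("myschema.dataset"). That is why the query code now builds an
empty prefix under the sqlite dialect. A minimal sketch of the difference,
using an in-memory SQLite database and a throwaway "dataset" table (both
stand-ins for illustration, not taken from the patch):

    from sqlalchemy import MetaData, create_engine, text

    engine = create_engine("sqlite:///:memory:")
    with engine.begin() as conn:
        conn.execute(text("CREATE TABLE dataset (dataset_id INTEGER PRIMARY KEY)"))

    metadata = MetaData()
    metadata.reflect(engine)

    # SQLite reflection keys tables by bare name; with PostgreSQL and
    # metadata.reflect(engine, schema="myschema") the key would be
    # "myschema.dataset", hence the conditional schema_str prefix.
    print(list(metadata.tables))  # -> ['dataset']

This is also why table.split(".")[1] had to be guarded in
_parse_selected_columns: on SQLite the key has no "." to split on.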
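Background on the fetchall() change: SQLAlchemy only guarantees a meaningful
CursorResult.rowcount for UPDATE and DELETE statements; for SELECTs the value
is driver-dependent, and Python's sqlite3 driver reports -1 regardless of how
many rows matched, so the old `result.rowcount == 0` check could never fire
under SQLite. Materializing the rows and counting them works on every
backend. A small sketch of the failure mode, again against a throwaway
in-memory table rather than the real registry schema:

    from sqlalchemy import create_engine, text

    engine = create_engine("sqlite:///:memory:")
    with engine.connect() as conn:
        conn.execute(text("CREATE TABLE keyword (keyword TEXT)"))
        conn.execute(text("INSERT INTO keyword VALUES ('supernova')"))

        # Unreliable for SELECT on SQLite: -1 even though one row matched.
        result = conn.execute(text("SELECT * FROM keyword WHERE keyword = 'supernova'"))
        print(result.rowcount)  # -> -1

        # Driver-independent existence check, as in the patch.
        result = conn.execute(text("SELECT * FROM keyword WHERE keyword = 'supernova'"))
        rows = result.fetchall()
        print(len(rows))  # -> 1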