Skip to content

Commit

Permalink
Add duplicate_column_names list to db_connection to help with querying
Browse files Browse the repository at this point in the history
  • Loading branch information
stuartmcalpine committed Dec 19, 2024
1 parent e37a038 commit 615a157
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 17 deletions.
31 changes: 31 additions & 0 deletions src/dataregistry/db_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from dataregistry import __version__
from dataregistry.exceptions import DataRegistryException
from dataregistry.schema import DEFAULT_SCHEMA_WORKING
from functools import cached_property

"""
Low-level utility routines and classes for accessing the registry
Expand Down Expand Up @@ -284,6 +285,36 @@ def _get_db_info(prov_table, get_associated_production=False):
# Store metadata
self.metadata["tables"] = metadata.tables

@cached_property
def duplicate_column_names(self):
"""
Probe the database for tables which share column names. This is used
later for querying.
Returns
-------
duplicates : list
List of column names that are duplicated across tables
"""

# Database hasn't been reflected yet
if len(self.metadata) == 0:
self._reflect()

# Find duplicate column names
duplicates = set()
all_columns = []
for table in self.metadata["tables"]:
for column in self.metadata["tables"][table].c:
if self.metadata["tables"][table].schema != self.active_schema:
continue

if column.name in all_columns:
duplicates.add(column.name)
all_columns.append(column.name)

return list(duplicates)

def get_table(self, tbl, schema=None):
"""
Get metadata for a specific table in the database.
Expand Down
30 changes: 13 additions & 17 deletions src/dataregistry/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,9 +199,20 @@ def _parse_selected_columns(self, column_names):
input_parts = col_name.split(".")
num_parts = len(input_parts)

# Make sure column name is value
if num_parts > 2:
raise ValueError(f"{col_name} is not a valid column")

if num_parts == 1:
if col_name in self.db_connection.duplicate_column_names:
raise DataRegistryException(
(
f"Column name '{col_name}' is not unique to one table "
f"in the database, use <table_name>.<column_name> "
f"format instead"
)
)

# Loop over each column in the database and find matches
for table in self.db_connection.metadata["tables"]:
for column in self.db_connection.metadata["tables"][table].c:
Expand All @@ -216,30 +227,15 @@ def _parse_selected_columns(self, column_names):
# Input is in <column> format
if input_parts[0] == table_parts[-1]:
tmp_column_list[column.table.schema].append(column)
tables_required.add(column.table.name)
elif num_parts == 2:
# Input is in <table>.<column> format
if (
input_parts[0] == table_parts[-2]
and input_parts[1] == table_parts[-1]
):
tmp_column_list[column.table.schema].append(column)

# Make sure we don't find multiple matches
for s in tmp_column_list.keys(): # Each schema
chk = []
for x in tmp_column_list[s]: # Each column in schema
if x.name in chk:
raise DataRegistryException(
(
f"Column name '{col_name}' is not unique to one table "
f"in the database, use <table_name>.<column_name> "
f"format instead"
)
)
chk.append(x.name)

# Add this table to the list
tables_required.add(x.table.name)
tables_required.add(column.table.name)

# Store results
for att in tmp_column_list.keys():
Expand Down

0 comments on commit 615a157

Please sign in to comment.