Skip to content

Commit

Permalink
refactor: remove Column and Join from SimbadClass
Browse files Browse the repository at this point in the history
  • Loading branch information
ManonMarchand committed Jun 20, 2024
1 parent 1cfbadb commit 7c4f0df
Show file tree
Hide file tree
Showing 3 changed files with 92 additions and 93 deletions.
119 changes: 59 additions & 60 deletions astroquery/simbad/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,23 @@ def _cached_query_tap(tap, query: str, *, maxrec=10000):
return tap.search(query, maxrec=maxrec).to_table()


@dataclass(frozen=True)
class _Column:
    """Immutable description of a single column used in a SIMBAD query.

    Being a frozen dataclass, instances cannot be modified after creation
    and are compared (and hashed) by the values of their fields.
    """
    # name of the table the column belongs to (e.g. "basic", "ident")
    table: str
    # name of the column inside ``table``; callers in this module also pass "*"
    name: str
    # optional alias attached to the column; ``None`` when no alias is given
    alias: str = None


@dataclass(frozen=True)
class _Join:
    """Immutable description of a join between two tables of a query.

    Being a frozen dataclass, instances cannot be modified after creation
    and are compared by the values of their fields.
    """
    # name of the table that is joined in
    table: str
    # the two sides of the join condition; in this module both are
    # ``_Column`` instances (typed ``Any``, so nothing enforces it here)
    column_left: Any
    column_right: Any
    # join keyword; this module uses the default "JOIN" and "LEFT JOIN"
    join_type: str = "JOIN"


class SimbadClass(BaseVOQuery):
"""The class for querying the SIMBAD web service.
Expand All @@ -84,30 +101,15 @@ class SimbadClass(BaseVOQuery):
"""
SIMBAD_URL = 'https://' + conf.server + '/simbad/sim-script'

@dataclass(frozen=True)
class Column:
"""A class to define a column in a SIMBAD query."""
table: str
name: str
alias: str = field(default=None)

@dataclass(frozen=True)
class Join:
"""A class to define a join between two tables."""
table: str
column_left: Any
column_right: Any
join_type: str = field(default="JOIN")

def __init__(self, ROW_LIMIT=None):
super().__init__()
# to create the TAPService
self._server = conf.server
self._tap = None
self._hardlimit = None
# attributes to construct ADQL queries
self._columns_in_output = None # a list of Simbad.Column
self.joins = [] # a list of Simbad.Join
self._columns_in_output = None # a list of _Column
self.joins = [] # a list of _Join
self.criteria = [] # a list of strings
self.ROW_LIMIT = ROW_LIMIT

Expand Down Expand Up @@ -165,7 +167,7 @@ def hardlimit(self):

@property
def columns_in_output(self):
"""A list of Simbad.Column.
"""A list of _Column.
They will be included in the output of the following methods:
Expand All @@ -178,7 +180,7 @@ def columns_in_output(self):
"""
if self._columns_in_output is None:
self._columns_in_output = [Simbad.Column("basic", item)
self._columns_in_output = [_Column("basic", item)
for item in conf.default_columns]
return self._columns_in_output

Expand Down Expand Up @@ -277,7 +279,7 @@ def _get_bundle_columns(self, bundle_name):
Returns
-------
list[Simbad.Column]
list[simbad._Column]
The list of columns corresponding to the selected bundle.
"""
basic_columns = set(map(str.casefold, set(self.list_columns("basic")["column_name"])))
Expand All @@ -287,10 +289,10 @@ def _get_bundle_columns(self, bundle_name):

if bundle_name in bundle_entries:
bundle = bundle_entries[bundle_name]
columns = [Simbad.Column("basic", column) for column in basic_columns
columns = [_Column("basic", column) for column in basic_columns
if column.startswith(bundle["tap_startswith"])]
if "tap_column" in bundle:
columns = [Simbad.Column("basic", column) for column in bundle["tap_column"]] + columns
columns = [_Column("basic", column) for column in bundle["tap_column"]] + columns
return columns

def _add_table_to_output(self, table):
Expand All @@ -308,7 +310,7 @@ def _add_table_to_output(self, table):
table = table.casefold()

if table == "basic":
self.columns_in_output.append(Simbad.Column(table, "*"))
self.columns_in_output.append(_Column(table, "*"))
return

linked_to_basic = self.list_linked_tables("basic")
Expand All @@ -329,10 +331,10 @@ def _add_table_to_output(self, table):
alias = [f'"{table}.{column}"' if not column.startswith(table) else None for column in columns]

# modify the attributes here
self.columns_in_output += [Simbad.Column(table, column, alias)
self.columns_in_output += [_Column(table, column, alias)
for column, alias in zip(columns, alias)]
self.joins += [Simbad.Join(table, Simbad.Column("basic", link["target_column"]),
Simbad.Column(table, link["from_column"]))]
self.joins += [_Join(table, _Column("basic", link["target_column"]),
_Column(table, link["from_column"]))]

def add_votable_fields(self, *args):
"""Add columns to the output of a SIMBAD query.
Expand Down Expand Up @@ -360,8 +362,8 @@ def add_votable_fields(self, *args):
>>> from astroquery.simbad import Simbad
>>> simbad = Simbad()
>>> simbad.add_votable_fields('sp_type', 'sp_qual', 'sp_bibcode') # doctest: +REMOTE_DATA
>>> simbad.columns_in_output[0] # doctest: +REMOTE_DATA
SimbadClass.Column(table='basic', name='main_id', alias=None)
>>> simbad.get_votable_fields() # doctest: +REMOTE_DATA
['basic.main_id', 'basic.ra', 'basic.dec', 'basic.coo_err_maj', 'basic.coo_err_min', ...
"""

# the legacy way of adding fluxes is the only case-dependent option
Expand All @@ -375,9 +377,9 @@ def add_votable_fields(self, *args):
flux_filter = re.findall(r"\((\w+)\)", arg)[0]
if len(flux_filter) == 1 and flux_filter.islower():
flux_filter = flux_filter + "_"
self.joins.append(self.Join("allfluxes", self.Column("basic", "oid"),
self.Column("allfluxes", "oidref")))
self.columns_in_output.append(self.Column("allfluxes", flux_filter))
self.joins.append(_Join("allfluxes", _Column("basic", "oid"),
_Column("allfluxes", "oidref")))
self.columns_in_output.append(_Column("allfluxes", flux_filter))
args.remove(arg)

# casefold args
Expand All @@ -391,7 +393,7 @@ def add_votable_fields(self, *args):
bundles = output_options[output_options["type"] == "bundle of basic columns"]["name"]

# Add columns from basic
self.columns_in_output += [Simbad.Column("basic", column) for column in args if column in basic_columns]
self.columns_in_output += [_Column("basic", column) for column in args if column in basic_columns]

# Add tables
tables_to_add = [table for table in args if table in all_tables]
Expand All @@ -415,7 +417,7 @@ def add_votable_fields(self, *args):
# some columns are still there but under a new name
if field_type == "alias":
tap_column = field_data["tap_column"]
self.columns_in_output.append(Simbad.Column("basic", tap_column))
self.columns_in_output.append(_Column("basic", tap_column))
warning_message = (f"'{votable_field}' has been renamed '{tap_column}'. You'll see it "
"appearing with its new name in the output table")
warnings.warn(warning_message, DeprecationWarning, stacklevel=2)
Expand Down Expand Up @@ -462,7 +464,7 @@ def reset_votable_fields(self):
- `query_criteria`.
"""
self.columns_in_output = [Simbad.Column("basic", item)
self.columns_in_output = [_Column("basic", item)
for item in conf.default_columns]
self.joins = []
self.criteria = []
Expand Down Expand Up @@ -555,9 +557,9 @@ def query_object(self, object_name, *, wildcard=False,
"""
top, columns, joins, instance_criteria = self._get_query_parameters()

columns.append(Simbad.Column("ident", "id", "matched_id"))
columns.append(_Column("ident", "id", "matched_id"))

joins.append(Simbad.Join("ident", Simbad.Column("basic", "oid"), Simbad.Column("ident", "oidref")))
joins.append(_Join("ident", _Column("basic", "oid"), _Column("ident", "oidref")))

if wildcard:
instance_criteria.append(rf" regexp(id, '{_wildcard_to_regexp(object_name)}') = 1")
Expand Down Expand Up @@ -626,10 +628,10 @@ def query_objects(self, object_names, *, wildcard=False, criteria=None,
instance_criteria.append(f"({criteria})")

if wildcard:
columns.append(Simbad.Column("ident", "id", "matched_id"))
joins += [Simbad.Join("ident", Simbad.Column("basic", "oid"),
Simbad.Column("ident", "oidref"))]
list_criteria = [f"regexp(id, '{_wildcard_to_regexp(object_name)}') = 1" for object_name in object_names]
columns.append(_Column("ident", "id", "matched_id"))
joins += [_Join("ident", _Column("basic", "oid"), _Column("ident", "oidref"))]
list_criteria = [f"regexp(id, '{_wildcard_to_regexp(object_name)}') = 1"
for object_name in object_names]
instance_criteria += [f'({" OR ".join(list_criteria)})']

return self._query(top, columns, joins, instance_criteria,
Expand All @@ -640,16 +642,15 @@ def query_objects(self, object_names, *, wildcard=False, criteria=None,
upload = Table({"user_specified_id": object_names,
"object_number_id": list(range(1, len(object_names) + 1))})
upload_name = "TAP_UPLOAD.script_infos"
columns.append(Simbad.Column(upload_name, "*"))
columns.append(_Column(upload_name, "*"))

left_joins = [Simbad.Join("ident", Simbad.Column(upload_name, "user_specified_id"),
Simbad.Column("ident", "id"), "LEFT JOIN"),
Simbad.Join("basic", Simbad.Column("basic", "oid"),
Simbad.Column("ident", "oidref"), "LEFT JOIN")]
left_joins = [_Join("ident", _Column(upload_name, "user_specified_id"),
_Column("ident", "id"), "LEFT JOIN"),
_Join("basic", _Column("basic", "oid"),
_Column("ident", "oidref"), "LEFT JOIN")]
for join in joins:
left_joins.append(Simbad.Join(join.table,
join.column_left,
join.column_right, "LEFT JOIN"))
left_joins.append(_Join(join.table, join.column_left,
join.column_right, "LEFT JOIN"))
return self._query(top, columns, left_joins, instance_criteria,
from_table=upload_name,
get_query_payload=get_query_payload,
Expand Down Expand Up @@ -814,9 +815,9 @@ def query_catalog(self, catalog, *, criteria=None, get_query_payload=False,
"""
top, columns, joins, instance_criteria = self._get_query_parameters()

columns.append(Simbad.Column("ident", "id", "catalog_id"))
columns.append(_Column("ident", "id", "catalog_id"))

joins += [Simbad.Join("ident", Simbad.Column("basic", "oid"), Simbad.Column("ident", "oidref"))]
joins += [_Join("ident", _Column("basic", "oid"), _Column("ident", "oidref"))]

instance_criteria.append(fr"id LIKE '{catalog} %'")
if criteria:
Expand Down Expand Up @@ -848,13 +849,11 @@ def query_bibobj(self, bibcode, *, criteria=None,
"""
top, columns, joins, instance_criteria = self._get_query_parameters()

joins += [Simbad.Join("has_ref", Simbad.Column("basic", "oid"),
Simbad.Column("has_ref", "oidref")),
Simbad.Join("ref", Simbad.Column("has_ref", "oidbibref"),
Simbad.Column("ref", "oidbib"))]
joins += [_Join("has_ref", _Column("basic", "oid"), _Column("has_ref", "oidref")),
_Join("ref", _Column("has_ref", "oidbibref"), _Column("ref", "oidbib"))]

columns += [Simbad.Column("ref", "bibcode"),
Simbad.Column("has_ref", "obj_freq")]
columns += [_Column("ref", "bibcode"),
_Column("has_ref", "obj_freq")]

instance_criteria.append(f"bibcode = '{_adql_parameter(bibcode)}'")
if criteria:
Expand Down Expand Up @@ -1071,11 +1070,11 @@ def query_criteria(self, *args, get_query_payload=False, **kwargs):
added_criteria = f"({CriteriaTranslator.parse(' & '.join(list(list(args) + list_kwargs)))})"
instance_criteria.append(added_criteria)
if "otypes." in added_criteria:
joins.append(self.Join("otypes", self.Column("basic", "oid"),
self.Column("otypes", "oidref")))
joins.append(_Join("otypes", _Column("basic", "oid"),
_Column("otypes", "oidref")))
if "allfluxes." in added_criteria:
joins.append(self.Join("allfluxes", self.Column("basic", "oid"),
self.Column("allfluxes", "oidref")))
joins.append(_Join("allfluxes", _Column("basic", "oid"),
_Column("allfluxes", "oidref")))
return self._query(top, columns, joins, instance_criteria,
get_query_payload=get_query_payload)

Expand Down
Loading

0 comments on commit 7c4f0df

Please sign in to comment.