Skip to content

Commit

Permalink
Fix SQL calls with wrong quotation for string literals.
Browse files Browse the repository at this point in the history
Many SQL statements were written as:
  SELECT ... FROM ... WHERE ... IN ("x","y",...)
and they should have been:
  SELECT ... FROM ... WHERE ... IN ('x','y',...)

SQLite used to accept that, but newer versions are stricter about it,
so we needed to update a lot of wrongly quoted strings.

See also https://www.sqlite.org/quirks.html#double_quoted_string_literals_are_accepted

Thanks Ana for spotting this!
  • Loading branch information
jordibc committed Mar 6, 2025
1 parent 327d0fb commit e04ae42
Show file tree
Hide file tree
Showing 4 changed files with 72 additions and 75 deletions.
41 changes: 20 additions & 21 deletions ete4/gtdb_taxonomy/gtdbquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,12 @@
DEFAULT_GTDBTAXADB = ETE_DATA_HOME + '/gtdbtaxa.sqlite'
DEFAULT_GTDBTAXADUMP = ETE_DATA_HOME + '/gtdbdump.tar.gz'


def as_csv(xs):
    """Return the items of xs as comma-separated SQL string literals.

    Each item is converted with str(), wrapped in single quotes (the
    standard SQL delimiter for string literals), and any single quote
    inside the value is doubled so it cannot terminate the literal early.

    An empty xs yields an empty string.

    Example: as_csv([1, "O'Brien"]) -> "'1','O''Brien'"
    """
    # Doubling embedded quotes is the SQL way of escaping them in a literal.
    return ','.join("'" + str(x).replace("'", "''") + "'" for x in xs)


def is_taxadb_up_to_date(dbfile=DEFAULT_GTDBTAXADB):
"""Check if a valid and up-to-date gtdbtaxa.sqlite database exists
If dbfile= is not specified, DEFAULT_GTDBTAXADB is assumed
Expand All @@ -36,9 +42,7 @@ def is_taxadb_up_to_date(dbfile=DEFAULT_GTDBTAXADB):

db.close()

if version != DB_VERSION:
return False
return True
return version == DB_VERSION


class GTDBTaxa:
Expand Down Expand Up @@ -94,7 +98,7 @@ def _connect(self):

def _translate_merged(self, all_taxids):
conv_all_taxids = set((list(map(int, all_taxids))))
cmd = 'select taxid_old, taxid_new FROM merged WHERE taxid_old IN (%s)' %','.join(map(str, all_taxids))
cmd = 'SELECT taxid_old, taxid_new FROM merged WHERE taxid_old IN (%s)' % as_csv(all_taxids)

result = self.db.execute(cmd)
conversion = {}
Expand Down Expand Up @@ -153,7 +157,7 @@ def _get_id2rank(self, internal_taxids):
Note: Numeric taxids are not recognized by the official GTDB taxonomy database, only for internal usage.
"""
ids = ','.join('"%s"' % v for v in set(internal_taxids) - {None, ''})
ids = as_csv(set(internal_taxids) - {None, ''})
result = self.db.execute('SELECT taxid, rank FROM species WHERE taxid IN (%s)' % ids)
return {tax: spname for tax, spname in result.fetchall()}

Expand All @@ -169,7 +173,7 @@ def get_rank(self, taxids):
name2ids = self._get_name_translator(taxids)
overlap_ids = name2ids.values()
taxids = [item for sublist in overlap_ids for item in sublist]
ids = ','.join('"%s"' % v for v in set(taxids) - {None, ''})
ids = as_csv(set(taxids) - {None, ''})
result = self.db.execute('SELECT taxid, rank FROM species WHERE taxid IN (%s)' % ids)
for tax, rank in result.fetchall():
taxid2rank[list(self._get_taxid_translator([tax]).values())[0]] = rank
Expand All @@ -183,8 +187,7 @@ def _get_lineage_translator(self, taxids):
all_ids = set(taxids)
all_ids.discard(None)
all_ids.discard("")
query = ','.join(['"%s"' %v for v in all_ids])
result = self.db.execute('SELECT taxid, track FROM species WHERE taxid IN (%s);' %query)
result = self.db.execute('SELECT taxid, track FROM species WHERE taxid IN (%s);' % as_csv(all_ids))
id2lineages = {}
for tax, track in result.fetchall():
id2lineages[tax] = list(map(int, reversed(track.split(","))))
Expand Down Expand Up @@ -229,8 +232,7 @@ def _get_lineage(self, taxid):
return list(reversed(track))

def get_common_names(self, taxids):
query = ','.join(['"%s"' %v for v in taxids])
cmd = "select taxid, common FROM species WHERE taxid IN (%s);" %query
cmd = "select taxid, common FROM species WHERE taxid IN (%s);" % as_csv(taxids)
result = self.db.execute(cmd)
id2name = {}
for tax, common_name in result.fetchall():
Expand All @@ -246,8 +248,7 @@ def _get_taxid_translator(self, taxids, try_synonyms=True):
all_ids = set(map(int, taxids))
all_ids.discard(None)
all_ids.discard("")
query = ','.join(['"%s"' %v for v in all_ids])
cmd = "select taxid, spname FROM species WHERE taxid IN (%s);" %query
cmd = "select taxid, spname FROM species WHERE taxid IN (%s);" % as_csv(all_ids)
result = self.db.execute(cmd)
id2name = {}
for tax, spname in result.fetchall():
Expand Down Expand Up @@ -282,17 +283,15 @@ def _get_name_translator(self, names):

names = set(name2origname.keys())

query = ','.join(['"%s"' %n for n in name2origname.keys()])
cmd = 'select spname, taxid from species where spname IN (%s)' %query
result = self.db.execute('select spname, taxid from species where spname IN (%s)' %query)
cmd = 'SELECT spname, taxid FROM species WHERE spname IN (%s)' % as_csv(name2origname.keys())
result = self.db.execute(cmd)
for sp, taxid in result.fetchall():
oname = name2origname[sp.lower()]
name2id.setdefault(oname, []).append(taxid)
#name2realname[oname] = sp
missing = names - set([n.lower() for n in name2id.keys()])
if missing:
query = ','.join(['"%s"' %n for n in missing])
result = self.db.execute('select spname, taxid from synonym where spname IN (%s)' %query)
result = self.db.execute('SELECT spname, taxid FROM synonym WHERE spname IN (%s)' % as_csv(missing))
for sp, taxid in result.fetchall():
oname = name2origname[sp.lower()]
name2id.setdefault(oname, []).append(taxid)
Expand Down Expand Up @@ -882,12 +881,12 @@ def upload_data(dbfile):

descendants = gtdb.get_descendant_taxa('c__Thorarchaeia', collapse_subspecies=True, return_tree=True)
print(descendants.write(properties=None))
print(descendants.get_ascii(properties=['sci_name', 'taxid','rank']))
print(descendants.get_ascii(properties=['sci_name', 'taxid', 'rank']))
tree = gtdb.get_topology(["p__Huberarchaeota", "o__Peptococcales", "f__Korarchaeaceae", "s__Korarchaeum"], intermediate_nodes=True, collapse_subspecies=True, annotate=True)
print(tree.get_ascii(properties=["taxid", "sci_name", "rank"]))
print(tree.get_ascii(properties=["taxid", "sci_name", "rank"]))

tree = PhyloTree('((c__Thorarchaeia, c__Lokiarchaeia_A), s__Caballeronia udeis);', sp_naming_function=lambda name: name)
tax2name, tax2track, tax2rank = gtdb.annotate_tree(tree, taxid_attr="name")
print(tree.get_ascii(properties=["taxid","name", "sci_name", "rank"]))
print(tree.get_ascii(properties=["taxid", "name", "sci_name", "rank"]))

print(gtdb.get_name_lineage(['RS_GCF_006228565.1','GB_GCA_001515945.1']))
print(gtdb.get_name_lineage(['RS_GCF_006228565.1', 'GB_GCA_001515945.1']))
34 changes: 16 additions & 18 deletions ete4/ncbi_taxonomy/ncbiquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,11 @@
DEFAULT_TAXDUMP = ETE_DATA_HOME + '/taxdump.tar.gz'


def as_csv(xs):
    """Return the items of xs as comma-separated SQL string literals.

    Each item is converted with str(), wrapped in single quotes (the
    standard SQL delimiter for string literals), and any single quote
    inside the value is doubled so it cannot terminate the literal early.

    An empty xs yields an empty string.

    Example: as_csv([9606, "it's"]) -> "'9606','it''s'"
    """
    # Doubling embedded quotes is the SQL way of escaping them in a literal.
    return ','.join("'" + str(x).replace("'", "''") + "'" for x in xs)


def is_taxadb_up_to_date(dbfile=DEFAULT_TAXADB):
"""Return True if a valid and up-to-date taxa.sqlite database exists.
Expand Down Expand Up @@ -97,7 +102,7 @@ def _translate_merged(self, all_taxids):

cmd = ('SELECT taxid_old, taxid_new '
'FROM merged WHERE taxid_old IN (%s)' %
','.join(map(str, all_taxids)))
as_csv(all_taxids))

result = self.db.execute(cmd)

Expand Down Expand Up @@ -128,7 +133,7 @@ def get_fuzzy_name_translation(self, name, sim=0.9):

print("Trying fuzzy search for %s" % name)
maxdiffs = math.ceil(len(name) * (1-sim))
cmd = (f'SELECT taxid, spname, LEVENSHTEIN(spname, "{name}") AS sim '
cmd = (f'SELECT taxid, spname, LEVENSHTEIN(spname, \'{name}\') AS sim '
f'FROM species WHERE sim <= {maxdiffs} ORDER BY sim LIMIT 1;')

taxid, spname, score = None, None, len(name)
Expand All @@ -137,7 +142,7 @@ def get_fuzzy_name_translation(self, name, sim=0.9):
taxid, spname, score = result.fetchone()
except TypeError:
cmd = (
f'SELECT taxid, spname, LEVENSHTEIN(spname, "{name}") AS sim '
f'SELECT taxid, spname, LEVENSHTEIN(spname, \'{name}\') AS sim '
f'FROM synonym WHERE sim <= {maxdiffs} ORDER BY sim LIMIT 1;')
result = _db.execute(cmd)
try:
Expand All @@ -161,8 +166,7 @@ def get_rank(self, taxids):
all_ids.discard(None)
all_ids.discard("")

query = ','.join('"%s"' % v for v in all_ids)
cmd = 'SELECT taxid, rank FROM species WHERE taxid IN (%s);' % query
cmd = 'SELECT taxid, rank FROM species WHERE taxid IN (%s);' % as_csv(all_ids)
result = self.db.execute(cmd)

id2rank = {}
Expand All @@ -180,8 +184,7 @@ def get_lineage_translator(self, taxids):
all_ids.discard(None)
all_ids.discard("")

query = ','.join('"%s"' % v for v in all_ids)
cmd = 'SELECT taxid, track FROM species WHERE taxid IN (%s);' % query
cmd = 'SELECT taxid, track FROM species WHERE taxid IN (%s);' % as_csv(all_ids)
result = self.db.execute(cmd)

id2lineages = {}
Expand Down Expand Up @@ -220,8 +223,7 @@ def get_lineage(self, taxid):
return list(reversed(track))

def get_common_names(self, taxids):
query = ','.join('"%s"' % v for v in taxids)
cmd = 'SELECT taxid, common FROM species WHERE taxid IN (%s);' % query
cmd = 'SELECT taxid, common FROM species WHERE taxid IN (%s);' % as_csv(taxids)
result = self.db.execute(cmd)

id2name = {}
Expand All @@ -237,8 +239,7 @@ def get_taxid_translator(self, taxids, try_synonyms=True):
all_ids.discard(None)
all_ids.discard("")

query = ','.join('"%s"' % v for v in all_ids)
cmd = 'SELECT taxid, spname FROM species WHERE taxid IN (%s);' % query
cmd = 'SELECT taxid, spname FROM species WHERE taxid IN (%s);' % as_csv(all_ids)
result = self.db.execute(cmd)

id2name = {}
Expand All @@ -252,8 +253,7 @@ def get_taxid_translator(self, taxids, try_synonyms=True):
new2old = {v: k for k,v in old2new.items()}

if old2new:
query = ','.join('"%s"' % v for v in new2old)
cmd = 'SELECT taxid, spname FROM species WHERE taxid IN (%s);' % query
cmd = 'SELECT taxid, spname FROM species WHERE taxid IN (%s);' % as_csv(new2old)
result = self.db.execute(cmd)
for tax, spname in result.fetchall():
id2name[new2old[tax]] = spname
Expand All @@ -273,18 +273,16 @@ def get_name_translator(self, names):

names = set(name2origname.keys())

query = ','.join('"%s"' % n for n in name2origname.keys())
cmd = 'SELECT spname, taxid FROM species WHERE spname IN (%s)' % query
result = self.db.execute('SELECT spname, taxid FROM species WHERE spname IN (%s)' % query)
cmd = 'SELECT spname, taxid FROM species WHERE spname IN (%s)' % as_csv(name2origname.keys())
result = self.db.execute(cmd)
for sp, taxid in result.fetchall():
oname = name2origname[sp.lower()]
name2id.setdefault(oname, []).append(taxid)
#name2realname[oname] = sp
missing = names - set([n.lower() for n in name2id.keys()])
if missing:
query = ','.join('"%s"' % n for n in missing)
result = self.db.execute('SELECT spname, taxid FROM synonym '
'WHERE spname IN (%s)' % query)
'WHERE spname IN (%s)' % as_csv(missing))
for sp, taxid in result.fetchall():
oname = name2origname[sp.lower()]
name2id.setdefault(oname, []).append(taxid)
Expand Down
Loading

0 comments on commit e04ae42

Please sign in to comment.