diff --git a/big_scape/genbank/gbk.py b/big_scape/genbank/gbk.py
index a5285e56..07fd85d9 100644
--- a/big_scape/genbank/gbk.py
+++ b/big_scape/genbank/gbk.py
@@ -6,6 +6,8 @@
 # from enum import Enum
 from pathlib import Path
+import random
+import string
 from typing import Dict, Optional
 import hashlib
 
@@ -14,6 +16,7 @@
 from Bio import SeqIO
 from Bio.SeqRecord import SeqRecord
 from Bio.SeqFeature import SeqFeature
+from sqlalchemy import Column, ForeignKey, Integer, String, Table, select
 
 # from other modules
 from big_scape.errors import InvalidGBKError
@@ -34,6 +37,63 @@
 #     MIBIG = "mibig"
 #     REFERENCE = "reference"
 
+# TODO: generalize creating temp tables. this is copied from network.py
+
+
+def create_temp_hash_table(gbks: list[GBK]) -> Table:
+    """Create a temporary table containing the hashes of the given GBKs
+
+    Args:
+        gbks (list[GBK]): the GBKs whose hashes to store in the temporary table
+
+    Returns:
+        Table: the temporary table
+    """
+
+    if DB.engine is None:
+        raise RuntimeError("DB engine is None")
+
+    if DB.metadata is None:
+        raise ValueError("DB metadata is None")
+
+    # generate a unique table name with a short random suffix
+    temp_table_name = "temp_" + "".join(random.choices(string.ascii_lowercase, k=10))
+
+    temp_table = Table(
+        temp_table_name,
+        DB.metadata,
+        Column(
+            "hash",
+            String,
+            ForeignKey(DB.metadata.tables["gbk"].c.hash),
+            primary_key=True,
+            nullable=False,
+        ),
+        prefixes=["TEMPORARY"],
+    )
+
+    DB.metadata.create_all(DB.engine)
+
+    cursor = DB.engine.raw_connection().driver_connection.cursor()
+
+    insert_query = f"""
+        INSERT INTO {temp_table_name} (hash) VALUES (?);
+    """
+
+    def batch_hash(gbks: list[GBK], n: int):
+        total = len(gbks)
+        for ndx in range(0, total, n):
+            yield [gbk.hash for gbk in gbks[ndx : min(ndx + n, total)]]
+
+    for hash_batch in batch_hash(gbks, 1000):
+        cursor.executemany(insert_query, [(x,) for x in hash_batch])  # type: ignore
+
+    cursor.close()
+
+    DB.commit()
+
+    return temp_table
+
 
 class GBK:
     """
@@ -261,7 +321,7 @@ def load_many(input_gbks: list[GBK]) -> list[GBK]:
             list[GBK]: loaded GBK objects
         """
 
-        input_gbk_hashes = [gbk.hash for gbk in input_gbks]
+        temp_hash_table = create_temp_hash_table(input_gbks)
 
         if not DB.metadata:
             raise RuntimeError("DB.metadata is None")
@@ -278,7 +338,7 @@ def load_many(input_gbks: list[GBK]) -> list[GBK]:
                 gbk_table.c.taxonomy,
                 gbk_table.c.description,
             )
-            .where(gbk_table.c.hash.in_(input_gbk_hashes))
+            .where(gbk_table.c.hash.in_(select(temp_hash_table.c.hash)))
             .compile()
         )
 
@@ -616,15 +676,15 @@ def collapse_hybrids_in_cand_clusters(
                 for number in cand_cluster.proto_clusters.keys()
             ]
             merged_protocluster = MergedProtoCluster.merge(protoclusters)
-            merged_tmp_proto_clusters[
-                merged_protocluster.number
-            ] = merged_protocluster
+            merged_tmp_proto_clusters[merged_protocluster.number] = (
+                merged_protocluster
+            )
 
             # update the protocluster old:new ids for the merged protoclusters of this cand_cluster
             for proto_cluster_num in cand_cluster.proto_clusters.keys():
-                merged_protocluster_ids[
-                    proto_cluster_num
-                ] = merged_protocluster.number
+                merged_protocluster_ids[proto_cluster_num] = (
+                    merged_protocluster.number
+                )
 
     # now we build a new version of the tmp_proto_clusters dict that contains the merged protoclusters
     # as well as protoclusters which did not need merging, with updated unique IDs/numbers
@@ -638,9 +698,9 @@ def collapse_hybrids_in_cand_clusters(
                 # this protocluster has been merged, so we need to add it to
                 # the dict with its new protocluster number
                 new_proto_cluster_num = merged_protocluster_ids[proto_cluster_num]
-                updated_tmp_proto_clusters[
-                    new_proto_cluster_num
-                ] = merged_tmp_proto_clusters[new_proto_cluster_num]
+                updated_tmp_proto_clusters[new_proto_cluster_num] = (
+                    merged_tmp_proto_clusters[new_proto_cluster_num]
+                )
                 updated_proto_cluster_dict[new_proto_cluster_num] = None
             else:
                 # protoclusters which have not been merged are added to the dict as is
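
Reviewer note: swapping the literal hash list for a temporary table presumably sidesteps SQLite's cap on bound parameters per statement (999 in older builds, 32766 since 3.32.0), which a plain `WHERE hash IN (?, ?, ...)` would exceed for large GBK inputs. Below is a minimal, self-contained sketch of the same pattern in SQLAlchemy 1.4+/2.x style against a throwaway in-memory database; the `gbk`/`temp_hashes` tables and the generated hash values are illustrative stand-ins, not BiG-SCAPE's actual schema or its `DB` wrapper.

from sqlalchemy import (
    Column,
    ForeignKey,
    MetaData,
    String,
    Table,
    create_engine,
    insert,
    select,
)

engine = create_engine("sqlite://")  # throwaway in-memory database
metadata = MetaData()

# stand-in for the persistent "gbk" table
gbk = Table("gbk", metadata, Column("hash", String, primary_key=True))
metadata.create_all(engine)

# TEMPORARY tables are connection-scoped in SQLite, so all work below
# happens on a single connection
temp = Table(
    "temp_hashes",
    metadata,
    Column("hash", String, ForeignKey(gbk.c.hash), primary_key=True),
    prefixes=["TEMPORARY"],
)

hashes = [f"hash_{i}" for i in range(5000)]  # dummy values, well over 999

with engine.begin() as conn:
    conn.execute(insert(gbk), [{"hash": h} for h in hashes])
    temp.create(conn)
    # insert in batches of 1000, mirroring batch_hash() in the diff
    for start in range(0, len(hashes), 1000):
        batch = hashes[start : start + 1000]
        conn.execute(insert(temp), [{"hash": h} for h in batch])
    # IN (SELECT ...) binds no parameter per hash, unlike IN (?, ?, ...),
    # so the statement stays small however many hashes are involved
    rows = conn.execute(
        select(gbk.c.hash).where(gbk.c.hash.in_(select(temp.c.hash)))
    ).fetchall()
    assert len(rows) == len(hashes)

Chunking the `IN` list to stay below the parameter limit would also work, but it means issuing several queries and merging their results; the temporary table keeps `load_many` to a single statement regardless of input size.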