Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/classify query #71

Merged
merged 16 commits into from
Nov 1, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions big_scape/cli/benchmark_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ def benchmark(ctx, *args, **kwargs):
"""
# get context parameters
ctx.obj.update(ctx.params)
ctx.obj["mode"] = "Benchmark"

# workflow validations
validate_output_paths(ctx)
Expand Down
16 changes: 0 additions & 16 deletions big_scape/cli/cli_common_options.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,22 +252,6 @@ def common_cluster_query(fn):
"the listed accessions will be analysed."
),
),
# comparison parameters
click.option(
"--legacy_classify",
is_flag=True,
help=(
"Does not use antiSMASH BGC classes to run analyses on "
"class-based bins, instead it uses BiG-SCAPE v1 predefined groups: "
"PKS1, PKSOther, NRPS, NRPS-PKS-hybrid, RiPP, Saccharide, Terpene, Others."
"Will also use BiG-SCAPEv1 legacy_weights for distance calculations."
"This feature is available for backwards compatibility with "
"antiSMASH versions up to v7. For higher antiSMASH versions, use"
" at your own risk, as BGC classes may have changed. All antiSMASH"
"classes that this legacy mode does not recognize will be grouped in"
" 'others'."
),
),
click.option(
"--legacy_weights",
is_flag=True,
Expand Down
17 changes: 17 additions & 0 deletions big_scape/cli/cluster_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,22 @@
"all input BGCs to the query in a one-vs-all mode."
),
)
# comparison parameters
@click.option(
"--legacy_classify",
is_flag=True,
help=(
"Does not use antiSMASH BGC classes to run analyses on "
"class-based bins, instead it uses BiG-SCAPE v1 predefined groups: "
"PKS1, PKSOther, NRPS, NRPS-PKS-hybrid, RiPP, Saccharide, Terpene, Others."
"Will also use BiG-SCAPEv1 legacy_weights for distance calculations."
"This feature is available for backwards compatibility with "
"antiSMASH versions up to v7. For higher antiSMASH versions, use"
" at your own risk, as BGC classes may have changed. All antiSMASH"
"classes that this legacy mode does not recognize will be grouped in"
" 'others'."
),
)
# binning parameters
@click.option("--no_mix", is_flag=True, help=("Dont run the all-vs-all analysis"))
# networking parameters
Expand All @@ -52,6 +68,7 @@ def cluster(ctx, *args, **kwargs):
# get context parameters
ctx.obj.update(ctx.params)
ctx.obj["query_bgc_path"] = None
ctx.obj["mode"] = "Cluster"

# workflow validations
validate_binning_cluster_workflow(ctx)
Expand Down
2 changes: 2 additions & 0 deletions big_scape/cli/query_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@ def query(ctx, *args, **kwarg):
# get context parameters
ctx.obj.update(ctx.params)
ctx.obj["no_mix"] = None
ctx.obj["legacy_classify"] = False
ctx.obj["mode"] = "Query"

# workflow validations
validate_skip_hmmscan(ctx)
Expand Down
2 changes: 2 additions & 0 deletions big_scape/comparison/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
QueryToRefRecordPairGenerator,
RefToRefRecordPairGenerator,
MissingRecordPairGenerator,
ConnectedComponenetPairGenerator,
generate_mix,
legacy_bin_generator,
legacy_get_class,
Expand All @@ -24,6 +25,7 @@
"QueryToRefRecordPairGenerator",
"RefToRefRecordPairGenerator",
"MissingRecordPairGenerator",
"ConnectedComponenetPairGenerator",
"generate_mix",
"ComparableRegion",
"generate_edges",
Expand Down
127 changes: 118 additions & 9 deletions big_scape/comparison/binning.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,40 @@ def add_records(self, record_list: list[BGCRecord]):
if None in self.record_ids:
raise ValueError("Region in bin has no db id!")

def cull_singletons(self, cutoff: float):
    """Remove singleton records from this bin for a given cutoff.

    A singleton is a record that either has no edges in the database, or
    whose edges all have a distance above/equal to the cutoff (for this
    bin's weight label).

    Args:
        cutoff (float): distance cutoff; only edges strictly below it keep
            a record in the bin

    Raises:
        RuntimeError: if the database has not been initialized
    """
    if not DB.metadata:
        raise RuntimeError("DB.metadata is None")

    distance_table = DB.metadata.tables["distance"]

    # all edges below the cutoff that touch at least one record of this
    # bin, restricted to this bin's weight label
    select_statement = (
        select(distance_table.c.region_a_id, distance_table.c.region_b_id)
        .where(
            distance_table.c.region_a_id.in_(self.record_ids)
            | distance_table.c.region_b_id.in_(self.record_ids)
        )
        .where(distance_table.c.distance < cutoff)
        .where(distance_table.c.weights == self.weights)
    )

    edges = DB.execute(select_statement).fetchall()

    # collect every record id that occurs in a kept edge
    connected_record_ids: set = set()
    for edge in edges:
        connected_record_ids.update(edge)

    # BUGFIX: the query above matches edges where only one endpoint is in
    # this bin, so connected_record_ids may contain foreign record ids.
    # Intersect with the bin's own ids so that record_ids stays consistent
    # with the filtered source_records below.
    self.record_ids = self.record_ids & connected_record_ids
    self.source_records = [
        record for record in self.source_records if record._db_id in self.record_ids
    ]

def __repr__(self) -> str:
return (
f"Bin '{self.label}': {self.num_pairs()} pairs from "
Expand All @@ -141,8 +175,8 @@ class QueryToRefRecordPairGenerator(RecordPairGenerator):
ref <-> ref pairs
"""

def __init__(self, label: str):
super().__init__(label)
def __init__(self, label: str, weights: Optional[str] = None):
super().__init__(label, weights)
self.reference_records: list[BGCRecord] = []
self.query_records: list[BGCRecord] = []

Expand Down Expand Up @@ -225,10 +259,11 @@ class RefToRefRecordPairGenerator(RecordPairGenerator):
source_records (list[BGCRecord]): List of BGC records to generate pairs from
"""

def __init__(self, label: str):
def __init__(self, label: str, weights: Optional[str] = None):
self.record_id_to_obj: dict[int, BGCRecord] = {}
self.reference_record_ids: set[int] = set()
self.done_record_ids: set[int] = set()
super().__init__(label)
super().__init__(label, weights)

def generate_pairs(self, legacy_sorting=False) -> Generator[RecordPair, None, None]:
"""Returns an Generator for Region pairs in this bin, pairs are only generated between
Expand Down Expand Up @@ -288,6 +323,9 @@ def add_records(self, record_list: list[BGCRecord]):
raise ValueError("Region in bin has no db id!")

self.record_id_to_obj[record._db_id] = record
if record.parent_gbk is not None:
if record.parent_gbk.source_type == SOURCE_TYPE.REFERENCE:
self.reference_record_ids.add(record._db_id)

return super().add_records(record_list)

Expand Down Expand Up @@ -315,16 +353,18 @@ def get_connected_reference_nodes(self) -> set[BGCRecord]:
select(distance_table.c.region_a_id)
.distinct()
.where(distance_table.c.distance < 1.0)
.where(distance_table.c.weights == self.weights)
),
bgc_record_table.c.id.in_(
select(distance_table.c.region_b_id)
.distinct()
.where(distance_table.c.distance < 1.0)
.where(distance_table.c.weights == self.weights)
),
)
)
.where(bgc_record_table.c.id.notin_(self.done_record_ids))
.where(gbk_table.c.source_type == SOURCE_TYPE.REFERENCE.value)
.where(bgc_record_table.c.id.in_(self.reference_record_ids))
.join(gbk_table, bgc_record_table.c.gbk_id == gbk_table.c.id)
)

Expand Down Expand Up @@ -364,16 +404,18 @@ def get_connected_reference_node_count(self) -> int:
select(distance_table.c.region_a_id)
.distinct()
.where(distance_table.c.distance < 1.0)
.where(distance_table.c.weights == self.weights)
),
bgc_record_table.c.id.in_(
select(distance_table.c.region_b_id)
.distinct()
.where(distance_table.c.distance < 1.0)
.where(distance_table.c.weights == self.weights)
),
)
)
.where(bgc_record_table.c.id.notin_(self.done_record_ids))
.where(gbk_table.c.source_type == SOURCE_TYPE.REFERENCE.value)
.where(bgc_record_table.c.id.in_(self.reference_record_ids))
.join(gbk_table, bgc_record_table.c.gbk_id == gbk_table.c.id)
)

Expand Down Expand Up @@ -406,16 +448,18 @@ def get_singleton_reference_nodes(self) -> set[BGCRecord]:
select(distance_table.c.region_a_id)
.distinct()
.where(distance_table.c.distance < 1.0)
.where(distance_table.c.weights == self.weights)
)
)
.where(
bgc_record_table.c.id.notin_(
select(distance_table.c.region_b_id)
.distinct()
.where(distance_table.c.distance < 1.0)
.where(distance_table.c.weights == self.weights)
)
)
.where(gbk_table.c.source_type == SOURCE_TYPE.REFERENCE.value)
.where(bgc_record_table.c.id.in_(self.reference_record_ids))
.join(gbk_table, bgc_record_table.c.gbk_id == gbk_table.c.id)
)

Expand Down Expand Up @@ -463,7 +507,7 @@ def get_singleton_reference_node_count(self) -> int:
.where(distance_table.c.distance < 1.0)
)
)
.where(gbk_table.c.source_type == SOURCE_TYPE.REFERENCE.value)
.where(bgc_record_table.c.id.in_(self.reference_record_ids))
.join(gbk_table, bgc_record_table.c.gbk_id == gbk_table.c.id)
)

Expand All @@ -472,6 +516,64 @@ def get_singleton_reference_node_count(self) -> int:
return singleton_reference_node_count


class ConnectedComponenetPairGenerator(RecordPairGenerator):
    """Generator that takes as input a connected component and generates
    all pairs from the nodes in the component.

    Each edge in the connected component is expected to be a tuple of
    (record_a_id, record_b_id, distance, jaccard, adjacency, dss, weights).

    NOTE(review): the class name keeps its original (misspelled) public
    identifier for backward compatibility with existing imports.
    """

    def __init__(self, connected_component, label: str):
        super().__init__(label)
        self.connected_component = connected_component
        # maps database record id -> BGCRecord for pair construction
        self.record_id_to_obj: dict[int, BGCRecord] = {}

    def add_records(self, record_list: list[BGCRecord]):
        """Adds BGC records to this bin and creates a generator for the pairs.

        Only records whose db id occurs in this generator's connected
        component are kept. Also creates a dictionary of record id to
        record objects.

        Args:
            record_list (list[BGCRecord]): candidate records to add

        Raises:
            ValueError: if a record has no database id
        """
        cc_record_ids: set = set()
        cc_record_list = []

        for edge in self.connected_component:
            # only the endpoint ids and the weight label are needed here;
            # the component scores (distance, jaccard, adjacency, dss)
            # are irrelevant for membership
            record_a_id, record_b_id, *_, weights = edge
            # Ensure that the correct weights are used,
            # the weights are set during the binning process
            self.weights = weights
            cc_record_ids.add(record_a_id)
            cc_record_ids.add(record_b_id)

        for record in record_list:
            if record._db_id is None:
                raise ValueError("Region in bin has no db id!")
            if record._db_id not in cc_record_ids:
                continue

            self.record_id_to_obj[record._db_id] = record
            cc_record_list.append(record)

        return super().add_records(cc_record_list)

    def generate_pairs(self, legacy_sorting=False) -> Generator[RecordPair, None, None]:
        """Returns a Generator for all pairs in this bin.

        Args:
            legacy_sorting (bool): if True, order each pair by gbk name to
                reproduce BiG-SCAPE v1 pair ordering
        """
        for edge in self.connected_component:
            record_a_id, record_b_id, *_, weights = edge
            # sanity check: every edge should carry the bin's weight label
            if self.weights != weights:
                logging.error(
                    "Edge in connected component does not have the same weight as the bin!"
                )

            record_a = self.record_id_to_obj[record_a_id]
            record_b = self.record_id_to_obj[record_b_id]

            if legacy_sorting:
                record_a, record_b = sorted((record_a, record_b), key=sort_name_key)

            yield RecordPair(record_a, record_b)


class MissingRecordPairGenerator(RecordPairGenerator):
"""Generator that wraps around another RecordPairGenerator to exclude any distances
already in the database
Expand Down Expand Up @@ -513,14 +615,21 @@ def generate_pairs(self, legacy_sorting=False) -> Generator[RecordPair, None, No
select(distance_table.c.region_a_id, distance_table.c.region_b_id)
.where(distance_table.c.region_a_id.in_(self.bin.record_ids))
.where(distance_table.c.region_b_id.in_(self.bin.record_ids))
.where(distance_table.c.weights == self.bin.weights)
)

# generate a set of tuples of region id pairs
existing_distances = set(DB.execute(select_statement).fetchall())

for pair in self.bin.generate_pairs(legacy_sorting):
# if the pair is not in the set of existing distances, yield it
if (pair.region_a._db_id, pair.region_b._db_id) not in existing_distances:
if (
pair.region_a._db_id,
pair.region_b._db_id,
) not in existing_distances and (
pair.region_a._db_id,
pair.region_b._db_id,
) not in existing_distances:
yield pair

def add_records(self, _: list[BGCRecord]):
Expand Down
3 changes: 1 addition & 2 deletions big_scape/data/schema.sql
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
CREATE TABLE IF NOT EXISTS gbk (
id INTEGER PRIMARY KEY AUTOINCREMENT,
path TEXT,
source_type TEXT,
nt_seq TEXT,
UNIQUE(path)
);
Expand Down Expand Up @@ -76,7 +75,7 @@ CREATE TABLE IF NOT EXISTS distance (
adjacency REAL NOT NULL,
dss REAL NOT NULL,
weights TEXT NOT NULL,
UNIQUE(region_a_id, region_b_id)
UNIQUE(region_a_id, region_b_id, weights)
FOREIGN KEY(region_a_id) REFERENCES bgc_record(id)
FOREIGN KEY(region_b_id) REFERENCES bgc_record(id)
);
Loading
Loading