diff --git a/bam_filter/__main__.py b/bam_filter/__main__.py
index c5dc21e..cd5f986 100644
--- a/bam_filter/__main__.py
+++ b/bam_filter/__main__.py
@@ -1,3 +1,4 @@
+#
"""
This program is free software: you can redistribute it and/or modify it under the terms of the GNU
General Public License as published by the Free Software Foundation, either version 3 of the
@@ -11,271 +12,35 @@
see <https://www.gnu.org/licenses/>.
"""
-
import logging
-import pandas as pd
-from bam_filter.sam_utils import process_bam, filter_reference_BAM, check_bam_file
+
+from bam_filter.reassign import reassign
+from bam_filter.filter import filter_references
+from bam_filter.lca import do_lca
from bam_filter.utils import (
get_arguments,
- create_output_files,
- concat_df,
)
-from bam_filter.entropy import find_knee
-import json
-import warnings
-from collections import Counter
-from functools import reduce
-import os
-import tempfile
log = logging.getLogger("my_logger")
-def handle_warning(message, category, filename, lineno, file=None, line=None):
- print("A warning occurred:")
- print(message)
- print("Do you wish to continue?")
-
- while True:
- response = input("y/n: ").lower()
- if response not in {"y", "n"}:
- print("Not understood.")
- else:
- break
-
- if response == "n":
- raise category(message)
-
-
-def obj_dict(obj):
- return obj.__dict__
-
-
-def get_summary(obj):
- return obj.to_summary()
-
-
-def get_lens(obj):
- return obj.get_read_length_freqs()
-
-
-# Check if the temporary directory exists, if not, create it
-def check_tmp_dir_exists(tmpdir):
- if tmpdir is None:
- tmpdir = tempfile.TemporaryDirectory(dir=os.getcwd())
- else:
- if not os.path.exists(tmpdir):
- log.error(f"Temporary directory {tmpdir} does not exist")
- exit(1)
- tmpdir = tempfile.TemporaryDirectory(dir=os.path.abspath(tmpdir))
- return tmpdir
-
-
def main():
logging.basicConfig(
level=logging.DEBUG,
format="%(levelname)s ::: %(asctime)s ::: %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)
-
+ logging.getLogger("urllib3").setLevel(logging.WARNING)
args = get_arguments()
-
- tmp_dir = check_tmp_dir_exists(args.tmp_dir)
- log.info("Temporary directory: %s", tmp_dir.name)
- if args.trim_min >= args.trim_max:
- log.error("trim_min must be less than trim_max")
- exit(1)
-
logging.getLogger("my_logger").setLevel(
logging.DEBUG if args.debug else logging.INFO
)
-
- if args.debug:
- warnings.showwarning = handle_warning
- else:
- warnings.filterwarnings("ignore")
-
- plot_coverage = False
- if args.coverage_plots is not None:
- plot_coverage = True
-
- if args.stats_filtered is None and (args.bam_filtered is not None):
- logging.error(
- "You need to specify a filtereds stats file to obtain the filtered BAM file"
- )
- exit(1)
-
- out_files = create_output_files(
- prefix=args.prefix,
- bam=args.bam,
- stats=args.stats,
- stats_filtered=args.stats_filtered,
- bam_filtered=args.bam_filtered,
- read_length_freqs=args.read_length_freqs,
- read_hits_count=args.read_hits_count,
- knee_plot=args.knee_plot,
- coverage_plots=args.coverage_plots,
- tmp_dir=tmp_dir,
- )
-
- bam = check_bam_file(
- bam=args.bam,
- threads=args.threads,
- reference_lengths=args.reference_lengths,
- sort_memory=args.sort_memory,
- )
-
- data = process_bam(
- bam=bam,
- threads=args.threads,
- reference_lengths=args.reference_lengths,
- min_read_count=args.min_read_count,
- min_read_ani=args.min_read_ani,
- trim_ends=args.trim_ends,
- trim_min=args.trim_min,
- trim_max=args.trim_max,
- scale=args.scale,
- plot=plot_coverage,
- plots_dir=out_files["coverage_plot_dir"],
- chunksize=args.chunk_size,
- read_length_freqs=args.read_length_freqs,
- output_files=out_files,
- sort_memory=args.sort_memory,
- low_memory=args.low_memory,
- )
- if args.low_memory:
- bam = out_files["bam_tmp_sorted"]
-
- logging.info("Reducing results to a single dataframe")
- # data = list(filter(None, data))
- data_df = [x[0] for x in data if x[0] is not None]
- data_df = concat_df(data_df)
-
- if args.read_length_freqs is not None:
- logging.info("Calculating read length frequencies...")
- lens = [x[1] for x in data if x[1] is not None]
- lens = json.dumps(lens, default=obj_dict, ensure_ascii=False, indent=4)
- with open(out_files["read_length_freqs"], "w", encoding="utf-8") as outfile:
- print(lens, file=outfile)
-
- if args.read_hits_count is not None:
- logging.info("Calculating read hits counts...")
- hits = [x[2] for x in data if x[2] is not None]
-
- # merge dicts and sum values
- hits = reduce(lambda x, y: x.update(y) or x, (Counter(dict(x)) for x in hits))
- # hits = sum(map(Counter, hits), Counter())
-
- # convert dict to dataframe
- hits = (
- pd.DataFrame.from_dict(hits, orient="index", columns=["count"])
- .rename_axis("read")
- .reset_index()
- .sort_values(by="count", ascending=False)
- )
-
- hits.to_csv(
- out_files["read_hits_count"], sep="\t", index=False, compression="gzip"
- )
-
- logging.info(f"Writing reference statistics to {out_files['stats']}")
- data_df.to_csv(out_files["stats"], sep="\t", index=False, compression="gzip")
-
- if args.min_norm_entropy is None or args.min_norm_gini is None:
- filter_conditions = {
- "min_read_length": args.min_read_length,
- "min_read_count": args.min_read_count,
- "min_expected_breadth_ratio": args.min_expected_breadth_ratio,
- "min_breadth": args.min_breadth,
- "min_avg_read_ani": args.min_avg_read_ani,
- "min_coverage_evenness": args.min_coverage_evenness,
- "min_coeff_var": args.min_coeff_var,
- "min_coverage_mean": args.min_coverage_mean,
- }
- elif args.min_norm_entropy == "auto" or args.min_norm_gini == "auto":
- if data_df.shape[0] > 1:
- min_norm_gini, min_norm_entropy = find_knee(
- data_df, out_plot_name=out_files["knee_plot"]
- )
-
- if min_norm_gini is None or min_norm_entropy is None:
- logging.warning(
- "Could not find knee in entropy plot. Disabling filtering by entropy/gini inequality."
- )
- filter_conditions = {
- "min_read_length": args.min_read_length,
- "min_read_count": args.min_read_count,
- "min_expected_breadth_ratio": args.min_expected_breadth_ratio,
- "min_breadth": args.min_breadth,
- "min_avg_read_ani": args.min_avg_read_ani,
- "min_coverage_evenness": args.min_coverage_evenness,
- "min_coeff_var": args.min_coeff_var,
- "min_coverage_mean": args.min_coverage_mean,
- }
- else:
- filter_conditions = {
- "min_read_length": args.min_read_length,
- "min_read_count": args.min_read_count,
- "min_expected_breadth_ratio": args.min_expected_breadth_ratio,
- "min_breadth": args.min_breadth,
- "min_avg_read_ani": args.min_avg_read_ani,
- "min_coverage_evenness": args.min_coverage_evenness,
- "min_coeff_var": args.min_coeff_var,
- "min_coverage_mean": args.min_coverage_mean,
- "min_norm_entropy": min_norm_entropy,
- "min_norm_gini": min_norm_gini,
- }
- else:
- min_norm_gini = 0.5
- min_norm_entropy = 0.75
- logging.warning(
- f"There's only one genome. Using min_norm_gini={min_norm_gini} and min_norm_entropy={min_norm_entropy}. Please check the results."
- )
- filter_conditions = {
- "min_read_length": args.min_read_length,
- "min_read_count": args.min_read_count,
- "min_expected_breadth_ratio": args.min_expected_breadth_ratio,
- "min_breadth": args.min_breadth,
- "min_avg_read_ani": args.min_avg_read_ani,
- "min_coverage_evenness": args.min_coverage_evenness,
- "min_coeff_var": args.min_coeff_var,
- "min_coverage_mean": args.min_coverage_mean,
- "min_norm_entropy": min_norm_entropy,
- "min_norm_gini": min_norm_gini,
- }
- else:
- min_norm_gini, min_norm_entropy = args.min_norm_gini, args.min_norm_entropy
- filter_conditions = {
- "min_read_length": args.min_read_length,
- "min_read_count": args.min_read_count,
- "min_expected_breadth_ratio": args.min_expected_breadth_ratio,
- "min_breadth": args.min_breadth,
- "min_avg_read_ani": args.min_avg_read_ani,
- "min_coverage_evenness": args.min_coverage_evenness,
- "min_coeff_var": args.min_coeff_var,
- "min_coverage_mean": args.min_coverage_mean,
- "min_norm_entropy": min_norm_entropy,
- "min_norm_gini": min_norm_gini,
- }
-
- if args.stats_filtered is not None:
- filter_reference_BAM(
- bam=bam,
- df=data_df,
- filter_conditions=filter_conditions,
- transform_cov_evenness=args.transform_cov_evenness,
- threads=args.threads,
- out_files=out_files,
- sort_memory=args.sort_memory,
- sort_by_name=args.sort_by_name,
- min_read_ani=args.min_read_ani,
- disable_sort=args.disable_sort,
- )
- else:
- logging.info("Skipping filtering of reference BAM file.")
- if args.low_memory:
- os.remove(out_files["bam_tmp_sorted"])
- logging.info("ALL DONE.")
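+    # dispatch to the subcommand selected on the command line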
+ if args.action == "reassign":
+ reassign(args)
+ elif args.action == "filter":
+ filter_references(args)
+ elif args.action == "lca":
+ do_lca(args)
if __name__ == "__main__":
diff --git a/bam_filter/filter.py b/bam_filter/filter.py
new file mode 100644
index 0000000..4e47bdb
--- /dev/null
+++ b/bam_filter/filter.py
@@ -0,0 +1,252 @@
+"""
+This program is free software: you can redistribute it and/or modify it under the terms of the GNU
+General Public License as published by the Free Software Foundation, either version 3 of the
+License, or (at your option) any later version.
+
+This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without
+even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+General Public License for more details.
+
+You should have received a copy of the GNU General Public License along with this program. If not,
+see <https://www.gnu.org/licenses/>.
+"""
+
+
+import logging
+import pandas as pd
+from bam_filter.sam_utils import process_bam, filter_reference_BAM, check_bam_file
+from bam_filter.utils import (
+ create_output_files,
+ concat_df,
+ check_tmp_dir_exists,
+ handle_warning,
+)
+from bam_filter.entropy import find_knee
+import json
+import warnings
+from collections import Counter
+from functools import reduce
+import os
+
+log = logging.getLogger("my_logger")
+
+
+def obj_dict(obj):
+ return obj.__dict__
+
+
+def get_summary(obj):
+ return obj.to_summary()
+
+
+def get_lens(obj):
+ return obj.get_read_length_freqs()
+
+
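+# Entry point for the "filter" subcommand: computes per-reference statistics from
+# the BAM file, writes the stats table, and optionally writes a filtered BAM with
+# the references that pass the selected thresholds.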
+def filter_references(args):
+ logging.basicConfig(
+ level=logging.DEBUG,
+ format="%(levelname)s ::: %(asctime)s ::: %(message)s",
+ datefmt="%H:%M:%S",
+ )
+
+ tmp_dir = check_tmp_dir_exists(args.tmp_dir)
+ log.info("Temporary directory: %s", tmp_dir.name)
+ if args.trim_min >= args.trim_max:
+ log.error("trim_min must be less than trim_max")
+ exit(1)
+
+ logging.getLogger("my_logger").setLevel(
+ logging.DEBUG if args.debug else logging.INFO
+ )
+
+ if args.debug:
+ warnings.showwarning = handle_warning
+ else:
+ warnings.filterwarnings("ignore")
+
+ plot_coverage = False
+ if args.coverage_plots is not None:
+ plot_coverage = True
+
+ if args.stats_filtered is None and (args.bam_filtered is not None):
+ logging.error(
+ "You need to specify a filtereds stats file to obtain the filtered BAM file"
+ )
+ exit(1)
+
+ out_files = create_output_files(
+ prefix=args.prefix,
+ bam=args.bam,
+ stats=args.stats,
+ stats_filtered=args.stats_filtered,
+ bam_filtered=args.bam_filtered,
+ read_length_freqs=args.read_length_freqs,
+ read_hits_count=args.read_hits_count,
+ knee_plot=args.knee_plot,
+ coverage_plots=args.coverage_plots,
+ tmp_dir=tmp_dir,
+ )
+
+ bam = check_bam_file(
+ bam=args.bam,
+ threads=args.threads,
+ reference_lengths=args.reference_lengths,
+ sort_memory=args.sort_memory,
+ )
+ data = process_bam(
+ bam=bam,
+ threads=args.threads,
+ reference_lengths=args.reference_lengths,
+ min_read_count=args.min_read_count,
+ min_read_ani=args.min_read_ani,
+ trim_ends=args.trim_ends,
+ trim_min=args.trim_min,
+ trim_max=args.trim_max,
+ scale=args.scale,
+ plot=plot_coverage,
+ plots_dir=out_files["coverage_plot_dir"],
+ read_length_freqs=args.read_length_freqs,
+ output_files=out_files,
+ sort_memory=args.sort_memory,
+ low_memory=args.low_memory,
+ )
+ if args.low_memory:
+ bam = out_files["bam_tmp_sorted"]
+
+ logging.info("Reducing results to a single dataframe")
+ # data = list(filter(None, data))
+ data_df = [x[0] for x in data if x[0] is not None]
+ data_df = concat_df(data_df)
+
+ if args.read_length_freqs is not None:
+ logging.info("Calculating read length frequencies...")
+ lens = [x[1] for x in data if x[1] is not None]
+ lens = json.dumps(lens, default=obj_dict, ensure_ascii=False, indent=4)
+ with open(out_files["read_length_freqs"], "w", encoding="utf-8") as outfile:
+ print(lens, file=outfile)
+
+ if args.read_hits_count is not None:
+ logging.info("Calculating read hits counts...")
+ hits = [x[2] for x in data if x[2] is not None]
+
+ # merge dicts and sum values
+ hits = reduce(lambda x, y: x.update(y) or x, (Counter(dict(x)) for x in hits))
+ # hits = sum(map(Counter, hits), Counter())
+
+ # convert dict to dataframe
+ hits = (
+ pd.DataFrame.from_dict(hits, orient="index", columns=["count"])
+ .rename_axis("read")
+ .reset_index()
+ .sort_values(by="count", ascending=False)
+ )
+
+ hits.to_csv(
+ out_files["read_hits_count"],
+ sep="\t",
+ index=False,
+ )
+
+ logging.info(f"Writing reference statistics to {out_files['stats']}")
+ data_df.to_csv(out_files["stats"], sep="\t", index=False)
+
+ if args.min_norm_entropy is None or args.min_norm_gini is None:
+ filter_conditions = {
+ "min_read_length": args.min_read_length,
+ "min_read_count": args.min_read_count,
+ "min_expected_breadth_ratio": args.min_expected_breadth_ratio,
+ "min_breadth": args.min_breadth,
+ "min_avg_read_ani": args.min_avg_read_ani,
+ "min_coverage_evenness": args.min_coverage_evenness,
+ "min_coeff_var": args.min_coeff_var,
+ "min_coverage_mean": args.min_coverage_mean,
+ }
+ elif args.min_norm_entropy == "auto" or args.min_norm_gini == "auto":
+ if data_df.shape[0] > 1:
+ min_norm_gini, min_norm_entropy = find_knee(
+ data_df, out_plot_name=out_files["knee_plot"]
+ )
+
+ if min_norm_gini is None or min_norm_entropy is None:
+ logging.warning(
+ "Could not find knee in entropy plot. Disabling filtering by entropy/gini inequality."
+ )
+ filter_conditions = {
+ "min_read_length": args.min_read_length,
+ "min_read_count": args.min_read_count,
+ "min_expected_breadth_ratio": args.min_expected_breadth_ratio,
+ "min_breadth": args.min_breadth,
+ "min_avg_read_ani": args.min_avg_read_ani,
+ "min_coverage_evenness": args.min_coverage_evenness,
+ "min_coeff_var": args.min_coeff_var,
+ "min_coverage_mean": args.min_coverage_mean,
+ }
+ else:
+ filter_conditions = {
+ "min_read_length": args.min_read_length,
+ "min_read_count": args.min_read_count,
+ "min_expected_breadth_ratio": args.min_expected_breadth_ratio,
+ "min_breadth": args.min_breadth,
+ "min_avg_read_ani": args.min_avg_read_ani,
+ "min_coverage_evenness": args.min_coverage_evenness,
+ "min_coeff_var": args.min_coeff_var,
+ "min_coverage_mean": args.min_coverage_mean,
+ "min_norm_entropy": min_norm_entropy,
+ "min_norm_gini": min_norm_gini,
+ }
+ else:
+ min_norm_gini = 0.5
+ min_norm_entropy = 0.75
+ logging.warning(
+ f"There's only one genome. Using min_norm_gini={min_norm_gini} and min_norm_entropy={min_norm_entropy}. Please check the results."
+ )
+ filter_conditions = {
+ "min_read_length": args.min_read_length,
+ "min_read_count": args.min_read_count,
+ "min_expected_breadth_ratio": args.min_expected_breadth_ratio,
+ "min_breadth": args.min_breadth,
+ "min_avg_read_ani": args.min_avg_read_ani,
+ "min_coverage_evenness": args.min_coverage_evenness,
+ "min_coeff_var": args.min_coeff_var,
+ "min_coverage_mean": args.min_coverage_mean,
+ "min_norm_entropy": min_norm_entropy,
+ "min_norm_gini": min_norm_gini,
+ }
+ else:
+ min_norm_gini, min_norm_entropy = args.min_norm_gini, args.min_norm_entropy
+ filter_conditions = {
+ "min_read_length": args.min_read_length,
+ "min_read_count": args.min_read_count,
+ "min_expected_breadth_ratio": args.min_expected_breadth_ratio,
+ "min_breadth": args.min_breadth,
+ "min_avg_read_ani": args.min_avg_read_ani,
+ "min_coverage_evenness": args.min_coverage_evenness,
+ "min_coeff_var": args.min_coeff_var,
+ "min_coverage_mean": args.min_coverage_mean,
+ "min_norm_entropy": min_norm_entropy,
+ "min_norm_gini": min_norm_gini,
+ }
+
+ if args.stats_filtered is not None:
+ filter_reference_BAM(
+ bam=bam,
+ df=data_df,
+ filter_conditions=filter_conditions,
+ transform_cov_evenness=args.transform_cov_evenness,
+ threads=args.threads,
+ out_files=out_files,
+ sort_memory=args.sort_memory,
+ sort_by_name=args.sort_by_name,
+ min_read_ani=args.min_read_ani,
+ disable_sort=args.disable_sort,
+ )
+ else:
+ logging.info("Skipping filtering of reference BAM file.")
+
+ if args.low_memory:
+ os.remove(out_files["bam_tmp_sorted"])
+ logging.info("ALL DONE.")
diff --git a/bam_filter/lca.py b/bam_filter/lca.py
new file mode 100644
index 0000000..e291f57
--- /dev/null
+++ b/bam_filter/lca.py
@@ -0,0 +1,585 @@
+import pysam
+import taxopy as txp
+from tqdm import tqdm
+from multiprocessing import Pool, Manager
+from functools import partial
+import pandas as pd
+import logging
+import networkx as nx
+from concurrent.futures import ProcessPoolExecutor
+import sys
+from bam_filter.utils import (
+ calc_chunksize,
+ sort_keys_by_approx_weight,
+ concat_df,
+ is_debug,
+ create_output_files,
+)
+from collections import defaultdict
+from functools import reduce
+import operator
+import random
+import gzip
+
+log = logging.getLogger("my_logger")
+
+debug = is_debug()
+
+
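+# Score a path by multiplying the cumulative edge weights ("cum_weight") of its
+# consecutive edges.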
+def calculate_path_likelihood(path, graph):
+ likelihood = 1.0
+ for u, v in zip(path[:-1], path[1:]):
+ likelihood *= graph[u][v]["cum_weight"]
+ return likelihood
+
+
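+# For the last node of an LCA path, enumerate its descendants in the full graph,
+# score the path to each one with calculate_path_likelihood, and return the
+# best-scoring reference together with the top 10 candidate paths. Nodes without
+# descendants, or at/above the selected rank, get None placeholders.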
+def find_most_likely_continuation_worker(partial_path_end, full_graph, index):
+ # find where are we in the taxonomy tree, if we are in the root, return
+
+ results = []
+ try:
+ descendants = list(nx.descendants(full_graph, partial_path_end))
+ except nx.NetworkXError:
+ descendants = []
+
+ if not descendants:
+ return (
+ partial_path_end,
+ {
+ "reference": None,
+ "best_path": None,
+ "top_10_paths": None,
+ },
+ )
+
+ if (
+ len(nx.shortest_path(full_graph, source="root", target=partial_path_end))
+ <= index + 1
+ ):
+        log.debug(f"{partial_path_end} is at or above the selected rank; skipping")
+ return (
+ partial_path_end,
+ {
+ "reference": None,
+ "best_path": None,
+ "top_10_paths": None,
+ },
+ )
+
+ for continuation in descendants:
+ full_path = list(
+ nx.all_simple_paths(
+ full_graph, source=partial_path_end, target=continuation
+ )
+ )[0]
+ likelihood = calculate_path_likelihood(full_path, full_graph)
+ results.append((full_path, likelihood))
+
+ results.sort(key=lambda x: x[1], reverse=True)
+
+ return (
+ partial_path_end,
+ {
+ "reference": results[0][0][-1] if results else None,
+ "best_path": results[0] if results else None,
+ "top_10_paths": results[:10] if results else None,
+ },
+ )
+
+
+def find_most_likely_continuation(full_graph, leaves, index, num_workers=1):
+ result_dict = {}
+
+ with ProcessPoolExecutor(max_workers=num_workers) as executor:
+ futures = [
+ executor.submit(
+ find_most_likely_continuation_worker,
+ partial_path_end,
+ full_graph,
+ index,
+ )
+ for partial_path_end in leaves
+ ]
+
+ for future in tqdm(futures, total=len(leaves), leave=False, ncols=80):
+ partial_path_end, result = future.result()
+ result_dict[partial_path_end] = result
+
+ return result_dict
+
+
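+# Recursively sum the "weight" and "norm_weight" edge attributes reachable from a
+# node (its clade, since the caller passes the reversed taxonomic graph) and return
+# the totals together with the path of visited nodes.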
+def calculate_cumulative_weight_and_path(graph, node, lengths):
+ # Base case: if the node is the root, return its cumulative weight and path
+ if not list(graph.predecessors(node)):
+ if node in lengths:
+ return 0, 0, [node]
+ else:
+ nei = list(graph.neighbors(node))[0]
+ cumulative_weight = graph[node][nei].get("weight", 0)
+ cumulative_norm_weight = graph[node][nei].get("norm_weight", 0)
+ return cumulative_weight, cumulative_norm_weight, [node]
+
+ # Recursive case: calculate cumulative weight and path by summing the edge weight
+ # and norm_weight with the cumulative weight and path of its parent(s)
+ cumulative_weight = 0
+ cumulative_norm_weight = 0
+ cumulative_path = []
+
+ for parent in graph.predecessors(node):
+ edge_weight = graph[parent][node].get("weight", 0)
+ norm_weight = graph[parent][node].get("norm_weight", 0)
+ (
+ parent_weight,
+ parent_norm_weight,
+ parent_path,
+ ) = calculate_cumulative_weight_and_path(graph, parent, lengths)
+
+ cumulative_weight += edge_weight + parent_weight
+ cumulative_norm_weight += norm_weight + parent_norm_weight
+        cumulative_path.extend(parent_path + [node])  # append the parent path followed by this node
+
+ return cumulative_weight, cumulative_norm_weight, cumulative_path
+
+
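+# Build the edge list for a single root-to-reference taxonomic path; every edge gets
+# zero weight except the terminal one, which receives the read count ("weight") and a
+# genome-length-normalised count scaled by `scale` ("norm_weight").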
+def create_tax_graph_w(tax_path, weight, lengths, scale=1_000_000):
+ root_row = pd.DataFrame({"source": ["root"], "target": ["root"], "weight": 0})
+ res = list(zip(tax_path, tax_path[1:]))
+ # get last element
+ res = pd.DataFrame(res, columns=["source", "target"])
+ res = pd.concat([root_row, res])
+ res = res.drop_duplicates()
+ res["weight"] = 0
+ res["norm_weight"] = 0
+ res.iloc[-1, res.columns.get_loc("weight")] = weight
+ res.iloc[-1, res.columns.get_loc("norm_weight")] = round(
+ (weight / lengths[res.iloc[-1, res.columns.get_loc("target")]]) * scale
+ )
+ return res
+
+
+def create_lca_df(tax_path, weight):
+ root_row = pd.DataFrame({"source": ["root"], "target": ["root"], "weight": 0})
+ res = list(zip(tax_path, tax_path[1:]))
+ # get last element
+ res = pd.DataFrame(res, columns=["source", "target"])
+ res = pd.concat([root_row, res])
+ res = res.drop_duplicates()
+ res["weight"] = 0
+ # add 1 to the last row weight
+ res.iloc[-1, res.columns.get_loc("weight")] = weight
+ return res
+
+
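+# For each reference in the chunk that has taxonomy information, collect the set of
+# read names aligned to it.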
+def get_ref2read(params, dat, threads=1):
+ bam, references = params
+ samfile = pysam.AlignmentFile(bam, "rb", threads=threads)
+ results = defaultdict(set)
+ for reference in references:
+ if reference not in dat:
+ continue
+ for aln in samfile.fetch(
+ contig=reference, multiple_iterators=False, until_eof=True
+ ):
+ results[reference].add(aln.query_name)
+ samfile.close()
+ return results
+
+
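+# Look up the taxid of a reference in the accession-to-taxid mapping and return its
+# rank->name dictionary from taxopy, adding the taxid, the reference name, and a
+# pseudo "subspecies" level named after the reference; returns None if the accession
+# is not in the mapping.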
+def get_tax(ref, parms):
+ taxdb = parms["taxdb"]
+ acc2taxid = parms["acc2taxid"]
+ if ref in acc2taxid:
+ taxid = acc2taxid[ref]
+ # taxid = txp.taxid_from_name(ref, taxdb)[0]
+ taxonomy_info = txp.Taxon(taxid, taxdb).rank_name_dictionary
+ taxonomy_info["taxid"] = taxid
+ taxonomy_info["ref"] = ref
+ taxonomy_info["subspecies"] = f"S__{ref}"
+ else:
+ log.debug(f"No taxid found for {ref}")
+ taxonomy_info = None
+ return taxonomy_info
+
+
+def get_taxonomy_info(refids, taxdb, acc2taxid, nprocs=1):
+ """Function to get the references taxonomic information for a given taxonomy id
+
+ Args:
+ taxids (list): A list of taxonomy ids
+ taxdb (taxopy.TaxonomyDB): A taxopy DB
+
+ Returns:
+ dict: A list of taxonomy information
+ """
+
+ acc2taxid_df = pd.read_csv(acc2taxid, sep="\t", index_col=None)[
+ ["accession", "taxid"]
+ ].rename(columns={"accession": "reference"}, inplace=False)
+ # Filter rows in refids from dataframe
+ acc2taxid_df = acc2taxid_df.loc[acc2taxid_df["reference"].isin(refids)]
+ acc2taxid_dict = acc2taxid_df.set_index("reference").T.to_dict("records")
+
+ parms = {"taxdb": taxdb, "acc2taxid": acc2taxid_dict[0]}
+ func = partial(get_tax, parms=parms)
+ if debug is True or len(refids) < 100000:
+ taxonomy_info = list(map(func, refids))
+ else:
+ p = Pool(nprocs)
+ c_size = calc_chunksize(nprocs, len(refids))
+ taxonomy_info = list(
+ tqdm(
+ p.imap_unordered(func, refids, chunksize=c_size),
+ total=len(refids),
+ leave=False,
+ ncols=100,
+ desc="References processed",
+ )
+ )
+ p.close()
+ p.join()
+ taxonomy_info = list(filter(None, taxonomy_info))
+ exclude = ["taxid", "ref"]
+ tax_ranks = []
+
+ for k in taxonomy_info[0].keys():
+ if k not in exclude:
+ tax_ranks.append(k)
+
+ taxonomy_info = {i["ref"]: i for i in taxonomy_info}
+ return taxonomy_info, tax_ranks
+
+
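+# Entry point for the "lca" subcommand: reads mapping to a single reference keep
+# their full taxonomic path, multi-mapping reads are collapsed to the lowest common
+# ancestor of their references, and the cumulative read counts and length-normalised
+# abundances per node are written to the LCA summary file.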
+def do_lca(args):
+ bam = args.bam
+ names = args.names
+ nodes = args.nodes
+ acc2taxid = args.acc2taxid
+ sel_rank = args.rank_lca
+ reference_lengths = args.reference_lengths
+ threads = args.threads
+ scale = args.scale
+
+ out_files = create_output_files(
+ bam=bam, prefix=args.prefix, lca_summary=args.lca_summary, tmp_dir=None
+ )
+
+ samfile = pysam.AlignmentFile(bam, "rb", threads=threads)
+ references = samfile.references
+ references_m = {
+ chrom.contig: chrom.mapped for chrom in samfile.get_index_statistics()
+ }
+
+ if reference_lengths is not None:
+        ref_lengths = pd.read_csv(
+            reference_lengths, sep="\t", index_col=None, names=["reference", "length"]
+        )
+        ref_lengths = dict(zip(ref_lengths["reference"], ref_lengths["length"]))
+        # check if the file contains all the references in the BAM file
+        if not set(references).issubset(set(ref_lengths.keys())):
+ logging.error(
+ "The BAM file contains references not found in the reference lengths file"
+ )
+ sys.exit(1)
+ else:
+ ref_lengths = {x: samfile.get_reference_length(x) for x in references}
+
+ log.info("Getting taxonomy information")
+ taxdb = txp.TaxDb(
+ nodes_dmp=nodes,
+ names_dmp=names,
+ )
+ taxonomy_info, tax_ranks = get_taxonomy_info(
+ references, taxdb, acc2taxid, nprocs=threads
+ )
+ tax_ranks.reverse()
+ index_r = tax_ranks.index(sel_rank)
+
+ dat = {}
+ for k, v in taxonomy_info.items():
+ try:
+ tax_list = ["root"]
+ tax_list.extend([v[rank] for rank in tax_ranks])
+ tax_list.append(k)
+ dat[k] = tuple(tax_list)
+ except KeyError:
+ continue
+
+ log.info("Getting reference to read mapping")
+ ref_chunks = sort_keys_by_approx_weight(
+ references_m, scale=1.5, num_cores=threads, refinement_steps=100
+ )
+
+ ref_chunks = random.sample(ref_chunks, len(ref_chunks))
+ dat_man = Manager().dict(dat)
+ params = zip([bam] * len(ref_chunks), ref_chunks)
+ p = Pool(
+ threads,
+ )
+ data = list(
+ tqdm(
+ p.imap_unordered(
+ partial(
+ get_ref2read,
+ dat=dat_man,
+ threads=1,
+ ),
+ params,
+ chunksize=1,
+ ),
+ total=len(ref_chunks),
+ leave=False,
+ ncols=80,
+ desc="Chunks processed",
+ )
+ )
+
+ p.close()
+ p.join()
+
+ data = reduce(operator.ior, data, {})
+
+ log.info("Getting reads taxonomic information")
+ reads = defaultdict(set)
+ for k, v in tqdm(data.items(), total=len(data), leave=False, ncols=80):
+ tax = dat[k]
+ for read in v:
+ reads[read].add(tax)
+
+ log.info("Creating taxonomic graph")
+ root_row = pd.DataFrame({"source": ["root"], "target": ["root"], "weight": [0]})
+ gs = []
+ if len(dat) > 0:
+ for t in tqdm(dat.values(), total=len(dat), leave=False, ncols=80):
+ # remove last element
+ # t = t[:-1]
+ res = list(zip(t, t[1:]))
+ res = pd.DataFrame(res, columns=["source", "target"])
+ res = res.drop_duplicates()
+ res["weight"] = 0
+ res = pd.concat([root_row, res])
+ gs.append(res)
+ gs = concat_df(gs).drop_duplicates()
+
+ G = nx.from_pandas_edgelist(gs, create_using=nx.DiGraph(), edge_attr=True)
+
+ dat = None
+
+ log.info("Calculating LCA")
+ cum_tax = defaultdict(int)
+ unique = []
+
+ discarded_lca = defaultdict(int)
+ for_lca = defaultdict(int)
+
+ # for k, v in tqdm(reads.items(), total=len(reads), leave=False, ncols=80):
+ # v = list(v)
+ # if len(v) > 1:
+ # tax_path = [edge for edge in v[0] if all(edge in t for t in v[1:])]
+ # if len(tax_path) <= index_r + 1:
+ # discarded_lca[tuple(tax_path)] += 1
+ # continue
+ # for_lca[tuple(tax_path)] += 1
+ # else:
+ # unique.append(k)
+ # cum_tax[v[0]] += 1
+
+ for k, v in tqdm(reads.items(), total=len(reads), leave=False, ncols=80):
+ v = list(v)
+ if len(v) > 1:
+ tax_path_set = set(v[0])
+ for t in v[1:]:
+ tax_path_set.intersection_update(t)
+
+ tax_path = list(tax_path_set)
+
+ if len(tax_path) <= index_r + 1:
+ discarded_lca[tuple(tax_path)] += 1
+ continue
+
+ for_lca[tuple(tax_path)] += 1
+ else:
+ unique.append(k)
+ cum_tax[v[0]] += 1
+
+ log.info(
+ f"Unique: {len(unique)} | LCA: {len(for_lca)} | Discarded: {len(discarded_lca)}"
+ )
+ log.info("Creating LCA dataframe")
+ if len(for_lca) > 0:
+ for_lca_df = []
+ for k, v in tqdm(for_lca.items(), total=len(for_lca), leave=False, ncols=80):
+ for_lca_df.append(create_lca_df(list(k), v))
+
+ log.info("Adding weights to the unique mapping taxonomic graph")
+ gs = []
+ for k, v in tqdm(cum_tax.items(), total=len(cum_tax), leave=False, ncols=80):
+ gs.append(create_tax_graph_w(list(k), v, ref_lengths, scale=scale))
+
+ df = concat_df(gs)
+ if len(for_lca) > 0:
+ df_l = concat_df(for_lca_df)
+ df_l["weight"] = 0
+ df_l["norm_weight"] = 0
+ df = concat_df([df, df_l])
+ # group by source and target and sum weight
+ df = df.groupby(["source", "target"]).sum().reset_index()
+
+ G = nx.from_pandas_edgelist(df, create_using=nx.DiGraph(), edge_attr=True)
+ G.remove_edges_from(nx.selfloop_edges(G))
+ Gr = G.reverse()
+
+ log.info("Calculating cumulative weights and paths on the taxonomic graph")
+ cumulative_weights_and_paths = {}
+ modified_graph = Gr.copy() # Create a copy of the original graph
+
+ for node in Gr.nodes:
+ (
+ cumulative_weight,
+ cumulative_norm_weight,
+ path,
+ ) = calculate_cumulative_weight_and_path(modified_graph, node, ref_lengths)
+ cumulative_weights_and_paths[node] = {
+ "cumulative_weight": cumulative_weight,
+ "cumulative_norm_weight": cumulative_norm_weight,
+ "path": ";".join(nx.shortest_path(G, target=node)["root"]),
+ }
+
+ # Update the modified graph with inferred edge attributes after all weights have been calculated
+ for parent in modified_graph.predecessors(node):
+ modified_graph[parent][node]["cum_weight"] = (
+ modified_graph[parent][node].get("weight", 0) + cumulative_weight
+ )
+ modified_graph[parent][node]["cum_norm_weight"] = (
+ modified_graph[parent][node].get("norm_weight", 0)
+ + cumulative_norm_weight
+ )
+
+ cumulative_weights_and_paths = dict(
+ sorted(cumulative_weights_and_paths.items(), key=lambda item: item[1]["path"])
+ )
+
+ if len(for_lca) > 0:
+ df1 = concat_df(for_lca_df)
+ df1 = df1.groupby(["source", "target"]).sum().reset_index()
+ G_lca = nx.from_pandas_edgelist(df1, create_using=nx.DiGraph(), edge_attr=True)
+ G_lca.remove_edges_from(nx.selfloop_edges(G_lca))
+
+ df1_d = dict(
+ zip(df1[df1["weight"] > 0]["target"], df1[df1["weight"] > 0]["weight"])
+ )
+ leaves = list(df1_d.keys())
+
+ log.info(
+ f"Finding most likely reference for the LCA nodes [{threads} threads]]"
+ )
+ test = find_most_likely_continuation(
+ full_graph=modified_graph.reverse(),
+ leaves=leaves,
+ index=index_r,
+ num_workers=threads,
+ )
+
+ lca_dfs = []
+ for k, v in test.items():
+ tax_path = nx.shortest_path(G, target=k)["root"]
+ root_row = pd.DataFrame(
+ {"source": ["root"], "target": ["root"], "weight": 0}
+ )
+ res = list(zip(tax_path, tax_path[1:]))
+ # get last element
+ res = pd.DataFrame(res, columns=["source", "target"])
+ res = pd.concat([root_row, res])
+ res = res.drop_duplicates()
+ res["weight"] = 0
+ res["norm_weight"] = 0
+ res.iloc[-1, res.columns.get_loc("weight")] = df1_d[k]
+ if v["reference"] is None:
+ res.iloc[-1, res.columns.get_loc("norm_weight")] = df1_d[k]
+ else:
+ res.iloc[-1, res.columns.get_loc("norm_weight")] = round(
+ scale * df1_d[k] / ref_lengths[v["reference"]]
+ )
+ lca_dfs.append(res)
+ log.info("Adding LCA nodes to the taxonomic graph")
+ df2 = concat_df(lca_dfs)
+ df2 = concat_df([df2, df])
+ df2 = df2.groupby(["source", "target"]).sum().reset_index()
+ # %%
+ G2 = nx.from_pandas_edgelist(df2, create_using=nx.DiGraph(), edge_attr=True)
+ G2.remove_edges_from(nx.selfloop_edges(G2))
+ Gr2 = G2.reverse()
+ # %%
+ log.info("Calculating cumulative weights and paths on the taxonomic graph")
+ cumulative_weights_and_paths = {}
+ modified_graph = Gr2.copy() # Create a copy of the original graph
+
+ for node in Gr2.nodes:
+ (
+ cumulative_weight,
+ cumulative_norm_weight,
+ path,
+ ) = calculate_cumulative_weight_and_path(modified_graph, node, ref_lengths)
+
+ cumulative_weights_and_paths[node] = {
+ "cumulative_weight": cumulative_weight,
+ "cumulative_norm_weight": cumulative_norm_weight,
+ "path": ";".join(nx.shortest_path(G, target=node)["root"]),
+ }
+
+ # Update the modified graph with inferred edge attributes after all weights have been calculated
+ for parent in modified_graph.predecessors(node):
+ modified_graph[parent][node]["cum_weight"] = (
+ modified_graph[parent][node].get("weight", 0) + cumulative_weight
+ )
+ modified_graph[parent][node]["cum_norm_weight"] = (
+ modified_graph[parent][node].get("norm_weight", 0)
+ + cumulative_norm_weight
+ )
+ cumulative_weights_and_paths = dict(
+ sorted(
+ cumulative_weights_and_paths.items(), key=lambda item: item[1]["path"]
+ )
+ )
+
+ log.info("Writing LCA results to file")
+ nodes = list(
+ set(
+ [
+ node
+ for node, data in cumulative_weights_and_paths.items()
+ if data["cumulative_weight"] > 0
+ ]
+ )
+ )
+ taxids = txp.taxid_from_name(nodes, taxdb)
+ taxids = dict(zip(nodes, taxids))
+
+ out_file = out_files["lca_summary"]
+    # open as gzip or plain text depending on the extension and write the summary
+    open_func = gzip.open if out_file.endswith(".gz") else open
+    with open_func(out_file, "wt") as f:
+        f.write("taxid\tname\trank\tn_reads\tabundance\ttax_path\n")
+        for node, data in tqdm(
+            cumulative_weights_and_paths.items(),
+            total=len(cumulative_weights_and_paths),
+        ):
+            if data["cumulative_weight"] > 0:
+                taxid = taxids[node][0]
+                rank = taxdb.taxid2rank[taxid]
+                f.write(
+                    f"{taxid}\t{node}\t{rank}\t{data['cumulative_weight']}\t{data['cumulative_norm_weight']}\t{data['path']}\n"
+                )
diff --git a/bam_filter/reassign.py b/bam_filter/reassign.py
new file mode 100644
index 0000000..69afbe5
--- /dev/null
+++ b/bam_filter/reassign.py
@@ -0,0 +1,785 @@
+import datatable as dt
+import logging
+import numpy as np
+import pysam
+import tqdm
+import sys
+import random
+from bam_filter.utils import (
+ create_empty_output_files,
+ sort_keys_by_approx_weight,
+ is_debug,
+ check_tmp_dir_exists,
+ handle_warning,
+ create_output_files,
+)
+from multiprocessing import Pool, Manager, cpu_count
+from functools import partial
+import numpy.lib.recfunctions as rf
+import gc
+from collections import defaultdict
+import os
+import concurrent.futures
+import math
+import warnings
+
+# import cProfile as prof
+# import pstats
+
+log = logging.getLogger("my_logger")
+
+
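+# Initialise the per-alignment fields used by the reassignment: s_W is the inverse of
+# the subject length and prob is the alignment score divided by the summed scores of
+# all alignments of the same read.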
+def initialize_subject_weights(data):
+ if data.shape[0] > 0:
+ # Add a new column for inverse sequence length
+ data["s_W"] = 1 / data["slen"]
+
+ # Calculate the sum of weights for each unique source
+ sum_weights = np.zeros(int(np.max(data["source"])) + 1)
+ np.add.at(sum_weights, data["source"], data["var"])
+
+ # Calculate the normalized weights based on the sum for each query
+ query_sum_var = sum_weights[data["source"]]
+ data["prob"] = data["var"] / query_sum_var
+
+ return data
+ else:
+ return None
+
+
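+# Iteratively reassign multi-mapping reads: each iteration recomputes the subject
+# weights from the current probabilities, rescales every alignment's probability, and
+# drops alignments whose probability falls below `scale` times the best probability of
+# that read. Stops when no alignment is removed or after `iters` iterations.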
+def resolve_multimaps(data, scale=0.9, iters=10):
+ current_iter = 0
+ while True:
+ progress_bar = tqdm.tqdm(
+ total=9,
+ desc=f"Iter {current_iter + 1}",
+ unit="step",
+            disable=False,
+ leave=False,
+ ncols=80,
+ )
+ log.debug(f"::: Iter: {current_iter + 1} - Getting scores")
+ progress_bar.update(1)
+ n_alns = data.shape[0]
+ log.debug(f"::: Iter: {current_iter + 1} - Total alignment: {n_alns:,}")
+
+ # Calculate the weights for each subject
+ log.debug(f"::: Iter: {current_iter + 1} - Calculating weights...")
+ progress_bar.update(1)
+ subject_weights = np.zeros(int(np.max(data["subject"])) + 1)
+ np.add.at(subject_weights, data["subject"], data["prob"])
+ data["s_W"] = subject_weights[data["subject"]] / data["slen"]
+ subject_weights = None
+
+ log.debug(f"::: Iter: {current_iter + 1} - Calculating probabilities")
+ progress_bar.update(1)
+ # Calculate the alignment probabilities
+ new_prob = data["prob"] * data["s_W"]
+ log.debug("Calculating sum of probabilities")
+ progress_bar.update(1)
+ prob_sum = data["prob"] * data["s_W"]
+ prob_sum_array = np.zeros(int(np.max(data["source"])) + 1)
+ np.add.at(prob_sum_array, data["source"], prob_sum)
+ prob_sum = None
+
+ # data["prob_sum"] = prob_sum_array[data["source"]]
+ data["prob"] = new_prob / prob_sum_array[data["source"]]
+ prob_sum_array = None
+
+ log.debug("Calculating query counts")
+ progress_bar.update(1)
+ # Calculate how many alignments are in each query
+ # query_counts = np.zeros(int(np.max(data["source"])) + 1)
+ # np.add.at(query_counts, data["source"], 1)
+
+ query_counts = np.bincount(data["source"])
+
+ log.debug("Calculating query counts array")
+ progress_bar.update(1)
+ # Use a separate array for query counts
+ query_counts_array = np.zeros(int(np.max(data["source"])) + 1)
+ np.add.at(
+ query_counts_array,
+ data["source"],
+ query_counts[data["source"]],
+ )
+
+ log.debug(
+ f"::: Iter: {current_iter + 1} - Calculating number of alignments per query"
+ )
+ progress_bar.update(1)
+ data["n_aln"] = query_counts_array[data["source"]]
+
+ log.debug("Calculating unique alignments")
+ data["n_aln"] = query_counts_array[data["source"]]
+ data_unique = data[data["n_aln"] == 1]
+ n_unique = data_unique.shape[0]
+
+ data = data[(data["n_aln"] > 1) & (data["prob"] > 0)]
+
+ # total_n_unique = np.sum(query_counts_array[data["source"]] <= 1)
+
+ query_counts = None
+ query_counts_array = None
+
+ log.debug("Calculating max_prob")
+ # Keep the ones that have a probability higher than the maximum scaled probability
+ max_prob = np.zeros(int(np.max(data["source"])) + 1)
+ np.maximum.at(max_prob, data["source"], data["prob"])
+
+ data["max_prob"] = max_prob[data["source"]]
+ data["max_prob"] = data["max_prob"] * scale
+ # data["max_prob"] = max_prob[data["source"]]
+ log.debug(
+ f"::: Iter: {current_iter + 1} - Removing alignments with lower probability"
+ )
+ progress_bar.update(1)
+ to_remove = np.sum(data["prob"] < data["max_prob"])
+
+ data = data[data["prob"] >= data["max_prob"]]
+ max_prob = None
+
+ # Update the iteration count in the function call
+ current_iter += 1
+ data["iter"] = current_iter
+ data_unique["iter"] = current_iter
+
+ query_counts = np.bincount(data["source"])
+ total_n_unique = np.sum(query_counts[data["source"]] <= 1)
+
+ # data_unique["iter"] = current_iter
+
+ # data = np.concatenate([data, data_unique])
+ data = np.concatenate([data, data_unique])
+ data_unique = None
+
+ keep_processing = to_remove != 0
+ log.debug(f"::: Iter: {current_iter} - Removed {to_remove:,} alignments")
+ log.debug(f"::: Iter: {current_iter} - Total mapping queries: {n_unique:,}")
+ log.debug(
+ f"::: Iter: {current_iter} - New unique mapping queries: {total_n_unique:,}"
+ )
+ log.debug(f"::: Iter: {current_iter} - Alns left: {data.shape[0]:,}")
+ progress_bar.update(1)
+ progress_bar.close()
+ log.info(
+ f"::: Iter: {current_iter} - R: {to_remove:,} | U: {total_n_unique:,} | NU: {n_unique:,} | L: {data.shape[0]:,}"
+ )
+ log.debug(f"::: Iter: {current_iter} - done!")
+
+ if iters > 0 and current_iter >= iters:
+ log.info("::: ::: Reached maximum iterations. Stopping.")
+ break
+ elif not keep_processing:
+ log.info("::: ::: No more alignments to remove. Stopping.")
+ break
+ return data
+
+
+# def write_reassigned_bam(
+# bam, out_files, threads, entries, sort_memory="1G", min_read_ani=90
+# ):
+# if out_files["bam_reassigned"] is not None:
+# out_bam = out_files["bam_reassigned"]
+# else:
+# out_bam = out_files["bam_reassigned_sorted"]
+
+# samfile = pysam.AlignmentFile(bam, "rb", threads=threads)
+# references = list(entries.keys())
+# refs_dict = {x: samfile.get_reference_length(x) for x in list(entries.keys())}
+
+# (ref_names, ref_lengths) = zip(*refs_dict.items())
+
+# refs_idx = {sys.intern(str(x)): i for i, x in enumerate(ref_names)}
+# if threads > 4:
+# write_threads = 4
+# else:
+# write_threads = threads
+
+# out_bam_file = pysam.AlignmentFile(
+# out_files["bam_reassigned_tmp"],
+# "wb",
+# referencenames=list(ref_names),
+# referencelengths=list(ref_lengths),
+# threads=write_threads,
+# )
+
+# for reference in tqdm.tqdm(
+# references,
+# total=len(references),
+# leave=False,
+# ncols=80,
+# desc="References processed",
+# ):
+# r_ids = entries[reference]
+# for aln in samfile.fetch(
+# reference=reference, multiple_iterators=False, until_eof=True
+# ):
+# # ani_read = (1 - ((aln.get_tag("NM") / aln.infer_query_length()))) * 100
+# if (aln.query_name, reference) in r_ids:
+# aln.reference_id = refs_idx[aln.reference_name]
+# out_bam_file.write(aln)
+# out_bam_file.close()
+
+
+# def write_to_file(alns, out_bam_file):
+# for aln in tqdm.tqdm(alns, total=len(alns), leave=False, ncols=80, desc="Writing"):
+# out_bam_file.write(aln)
+
+
+def write_to_file(alns, out_bam_file, header=None):
+ for aln in alns:
+ out_bam_file.write(pysam.AlignedSegment.fromstring(aln, header))
+
+
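+# Fetch the alignments of each reference in the batch, keep the (read, reference)
+# pairs selected during reassignment, re-map the reference id to the new header, and
+# return the alignments as SAM-formatted strings.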
+def process_references_batch(references, entries, bam, refs_idx, threads=4):
+ alns = []
+ with pysam.AlignmentFile(bam, "rb", threads=threads) as samfile:
+ for reference in references:
+ r_ids = entries[reference]
+ for aln in samfile.fetch(
+ reference=reference, multiple_iterators=False, until_eof=True
+ ):
+ if (aln.query_name, reference) in r_ids:
+ aln.reference_id = refs_idx[aln.reference_name]
+ alns.append(aln.to_string())
+
+ return alns
+
+
+def write_reassigned_bam(
+ bam,
+ ref_counts,
+ out_files,
+ threads,
+ entries,
+ sort_memory="1G",
+ sort_by_name=False,
+ min_read_ani=90,
+ min_read_length=30,
+):
+ # if out_files["bam_reassigned"] is not None:
+ # out_bam = out_files["bam_reassigned"]
+ # else:
+ # out_bam = out_files["bam_reassigned_sorted"]
+ out_bam = out_files["bam_reassigned"]
+ samfile = pysam.AlignmentFile(bam, "rb", threads=threads)
+ references = list(entries.keys())
+ refs_dict = {x: samfile.get_reference_length(x) for x in references}
+ header = samfile.header
+ samfile.close()
+ (ref_names, ref_lengths) = zip(*refs_dict.items())
+
+ refs_idx = {sys.intern(str(x)): i for i, x in enumerate(ref_names)}
+ if threads > 4:
+ write_threads = 4
+ else:
+ write_threads = threads
+
+ out_bam_file = pysam.AlignmentFile(
+ out_files["bam_reassigned_tmp"],
+ "wb",
+ referencenames=list(ref_names),
+ referencelengths=list(ref_lengths),
+ threads=write_threads,
+ )
+
+ num_cores = min(threads, cpu_count())
+ # batch_size = len(references) // num_cores + 1 # Ensure non-zero batch size
+ # batch_size = calc_chunksize(n_workers=num_cores, len_iterable=len(references))
+ log.info("::: Creating reference chunks with uniform read amounts...")
+ ref_chunks = sort_keys_by_approx_weight(
+ input_dict=ref_counts, scale=1.5, num_cores=threads
+ )
+ num_cores = min(num_cores, len(ref_chunks))
+ log.info(f"::: Using {num_cores} cores to write {len(ref_chunks)} chunk(s)")
+
+ with Manager() as manager:
+ # Use Manager to create a read-only proxy for the dictionary
+ entries = manager.dict(dict(entries))
+
+ with concurrent.futures.ProcessPoolExecutor(max_workers=num_cores) as executor:
+ # Use ProcessPoolExecutor to parallelize the processing of references in batches
+ futures = []
+ for batch_references in tqdm.tqdm(
+ ref_chunks,
+ total=len(ref_chunks),
+ desc="Submitted batches",
+ unit="batch",
+ leave=False,
+ ncols=80,
+ disable=is_debug(),
+ ):
+ future = executor.submit(
+ process_references_batch, batch_references, entries, bam, refs_idx
+ )
+ futures.append(future) # Store the future
+
+ # Use a while loop to continuously check for completed futures
+ log.info("::: Collecting batches...")
+
+ completion_progress_bar = tqdm.tqdm(
+ total=len(futures),
+ desc="Completed",
+ unit="batch",
+ leave=False,
+ ncols=80,
+ disable=is_debug(),
+ )
+ completed_count = 0
+
+ # Use as_completed to iterate over completed futures as they become available
+ for completed_future in concurrent.futures.as_completed(futures):
+ alns = completed_future.result()
+ write_to_file(alns=alns, out_bam_file=out_bam_file, header=header)
+
+ # Update the progress bar for each completed write
+ completion_progress_bar.update(1)
+ completed_count += 1
+ completed_future.cancel() # Cancel the future to free memory
+ gc.collect() # Force garbage collection
+
+ completion_progress_bar.close()
+ out_bam_file.close()
+ entries = None
+ gc.collect()
+ # prof.disable()
+ # # print profiling output
+ # stats = pstats.Stats(prof).strip_dirs().sort_stats("tottime")
+ # stats.print_stats(5) # top 10 rows
+ log.info("::: ::: Sorting BAM file...")
+ if sort_by_name:
+ log.info("::: ::: Sorting by name...")
+ pysam.sort(
+ "-n",
+ "-@",
+ str(threads),
+ "-m",
+ str(sort_memory),
+ "-o",
+ out_bam,
+ out_files["bam_reassigned_tmp"],
+ )
+ else:
+ pysam.sort(
+ "-@",
+ str(threads),
+ "-m",
+ str(sort_memory),
+ "-o",
+ out_bam,
+ out_files["bam_reassigned_tmp"],
+ )
+
+ logging.info("BAM index not found. Indexing...")
+ save = pysam.set_verbosity(0)
+ samfile = pysam.AlignmentFile(out_bam, "rb", threads=threads)
+ chr_lengths = []
+ for chrom in samfile.references:
+ chr_lengths.append(samfile.get_reference_length(chrom))
+ max_chr_length = np.max(chr_lengths)
+ pysam.set_verbosity(save)
+ samfile.close()
+
+ if max_chr_length > 536870912:
+ logging.info("A reference is longer than 2^29, indexing with csi")
+ pysam.index(
+ "-c",
+ "-@",
+ str(threads),
+ out_bam,
+ )
+ else:
+ pysam.index(
+ "-@",
+ str(threads),
+ out_bam,
+ )
+
+ os.remove(out_files["bam_reassigned_tmp"])
+
+
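+# Combine percent identity and read length into a single alignment score: the identity
+# is damped by log(length + 1) and boosted by sqrt(length), so longer reads with the
+# same identity score higher.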
+def calculate_alignment_score(identity, read_length):
+ return (identity / math.log(read_length + 1)) * math.sqrt(read_length)
+
+
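+# Parse one chunk of references from the BAM file: keep alignments whose percent
+# identity (from the NM tag) is at least `percid`, score them, and return a datatable
+# frame with the best-scoring alignment per (read, reference) pair together with the
+# read and reference names seen and the number of references without alignments.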
+def get_bam_data(parms, ref_lengths=None, percid=90, min_read_length=30, threads=1):
+ bam, references = parms
+ samfile = pysam.AlignmentFile(bam, "rb", threads=threads)
+
+ results = []
+ reads = set()
+ refs = set()
+ empty_df = 0
+
+ for reference in references:
+ if ref_lengths is None:
+ reference_length = int(samfile.get_reference_length(reference))
+ else:
+ reference_length = int(ref_lengths[reference])
+ aln_data = []
+ for aln in samfile.fetch(
+ contig=reference, multiple_iterators=False, until_eof=True
+ ):
+ # AS_value = aln.get_tag("AS") if aln.has_tag("AS") else None
+ # XS_value = aln.get_tag("XS") if aln.has_tag("XS") else None
+ query_length = (
+ aln.query_length if aln.query_length != 0 else aln.infer_query_length()
+ )
+
+ # if query_length < min_read_length:
+ # continue
+
+ pident = (1 - ((aln.get_tag("NM") / query_length))) * 100
+ if pident >= percid:
+ var = calculate_alignment_score(pident, query_length)
+ aln_data.append(
+ (
+ aln.query_name,
+ aln.reference_name,
+ var,
+ reference_length,
+ )
+ )
+ reads.add(aln.query_name)
+ refs.add(aln.reference_name)
+ # remove duplicates
+ # check if aln_data is not empty
+ if len(aln_data) > 0:
+ aln_data_dt = dt.Frame(aln_data)
+ # "queryId", "subjectId", "bitScore", "slen"
+ aln_data_dt.names = ["queryId", "subjectId", "var", "slen"]
+ # remove duplicates and keep the ones with the largest mapq
+ aln_data_dt = aln_data_dt[
+ :1, :, dt.by(dt.f.queryId, dt.f.subjectId), dt.sort(-dt.f.var)
+ ]
+ # aln_data_dt["slen"] = dt.Frame([ref_lengths[x] for x in aln_data_dt[:,"subjectId"].to_list()[0]])
+ results.append(aln_data_dt)
+ else:
+ results.append(dt.Frame())
+ empty_df += 1
+ samfile.close()
+ return (dt.rbind(results), reads, refs, empty_df)
+
+
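+# Driver for the reassignment: extracts the filtered alignments per reference chunk,
+# indexes reads and references, resolves multi-mapping reads iteratively, and writes
+# the surviving alignments of references with at least `min_read_count` reads to a new
+# sorted and indexed BAM file.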
+def reassign_reads(
+ bam,
+ out_files,
+ reference_lengths=None,
+ threads=1,
+ min_read_count=1,
+ min_read_ani=90,
+ min_read_length=30,
+ reassign_iters=25,
+ reassign_scale=0.9,
+ sort_memory="4G",
+ sort_by_name=False,
+):
+ dt.options.progress.enabled = True
+ dt.options.progress.clear_on_success = True
+ if threads > 1:
+ dt.options.nthreads = threads - 1
+ else:
+ dt.options.nthreads = 1
+
+ log.info("::: Loading BAM file")
+ save = pysam.set_verbosity(0)
+ samfile = pysam.AlignmentFile(bam, "rb", threads=threads)
+
+ references = samfile.references
+
+ pysam.set_verbosity(save)
+
+ if reference_lengths is not None:
+ ref_len_dt = dt.fread(reference_lengths)
+ ref_len_dt.names = ["subjectId", "slen"]
+ # convert to dict
+ ref_len_dict = dict(
+ zip(ref_len_dt["subjectId"].to_list()[0], ref_len_dt["slen"].to_list()[0])
+ )
+ # check if the dataframe contains all the References in the BAM file
+ if not set(references).issubset(set(ref_len_dict.keys())):
+ logging.error(
+ "The BAM file contains references not found in the reference lengths file"
+ )
+ sys.exit(1)
+ else:
+ ref_len_dict = None
+
+ total_refs = samfile.nreferences
+ log.info(f"::: Found {total_refs:,} reference sequences")
+ # logging.info(f"Found {samfile.mapped:,} alignments")
+ log.info(f"::: Removing references with less than {min_read_count} reads...")
+ references_m = {
+ chrom.contig: chrom.mapped
+ for chrom in samfile.get_index_statistics()
+ if chrom.mapped >= min_read_count
+ }
+
+ references = list(references_m.keys())
+
+ if len(references) == 0:
+ log.warning("::: No reference sequences with alignments found in the BAM file")
+ create_empty_output_files(out_files)
+ sys.exit(0)
+
+ log.info(f"::: Keeping {len(references):,} references")
+
+ log.info("::: Creating reference chunks with uniform read amounts...")
+    # group references so that each chunk gets a similar number of mapped reads
+ ref_chunks = sort_keys_by_approx_weight(
+ input_dict=references_m, scale=1.5, num_cores=threads
+ )
+ log.info(f"::: Created {len(ref_chunks):,} chunks")
+ ref_chunks = random.sample(ref_chunks, len(ref_chunks))
+
+ dt.options.progress.enabled = False
+ dt.options.progress.clear_on_success = True
+ dt.options.nthreads = 1
+
+ parms = list(zip([bam] * len(ref_chunks), ref_chunks))
+
+ log.info("::: Extracting reads from BAM file...")
+ if is_debug():
+ data = list(
+ tqdm.tqdm(
+ map(
+ partial(
+ get_bam_data,
+ ref_lengths=ref_len_dict,
+ percid=min_read_ani,
+ min_read_length=min_read_length,
+ threads=4,
+ ),
+ parms,
+ chunksize=1,
+ ),
+ total=len(parms),
+ leave=False,
+ ncols=80,
+ desc="Chunks processed",
+ )
+ )
+ else:
+ p = Pool(
+ threads,
+ )
+ data = list(
+ tqdm.tqdm(
+ p.imap_unordered(
+ partial(
+ get_bam_data,
+ ref_lengths=ref_len_dict,
+ percid=min_read_ani,
+ threads=4,
+ ),
+ parms,
+ chunksize=1,
+ ),
+ total=len(parms),
+ leave=False,
+ ncols=80,
+ desc="Chunks processed",
+ )
+ )
+
+ p.close()
+ p.join()
+
+ dt.options.progress.enabled = True
+ dt.options.progress.clear_on_success = True
+ if threads > 1:
+ dt.options.nthreads = threads - 1
+ else:
+ dt.options.nthreads = 1
+
+ log.info("::: Collecting results...")
+ reads = list()
+ refs = list()
+ empty_df = 0
+ for i in tqdm.tqdm(range(len(data)), total=len(data), leave=False, ncols=80):
+ empty_df += data[i][3]
+ reads.extend(list(data[i][1]))
+ refs.extend(list(data[i][2]))
+ data[i] = data[i][0]
+ log.info(f"::: ::: Found {empty_df:,} references without alignments")
+
+ data = dt.rbind([x for x in data])
+
+ log.info("::: Indexing references...")
+ refs = dt.Frame(list(set(refs)))
+ refs.names = ["subjectId"]
+ refs["sidx"] = dt.Frame(list(range(refs.shape[0])))
+ refs.key = "subjectId"
+
+ log.info("::: Indexing reads...")
+ reads = dt.Frame(list(set(reads)))
+ reads.names = ["queryId"]
+ reads["qidx"] = dt.Frame([(i + refs.shape[0]) for i in range(reads.shape[0])])
+ reads.key = "queryId"
+
+ log.info("::: Combining data...")
+ data = data[:, :, dt.join(reads)]
+ data = data[:, :, dt.join(refs)]
+
+ del data["queryId"]
+ del data["subjectId"]
+ n_alns_0 = data.shape[0]
+ n_reads_0 = reads.shape[0]
+ n_refs_0 = refs.shape[0]
+ log.info(
+ f"::: References: {n_refs_0:,} | Reads: {n_reads_0:,} | Alignments: {n_alns_0:,}"
+ )
+
+ log.info("::: Allocating data structures...")
+ mat = data[:, [dt.f.qidx, dt.f.sidx, dt.f.var, dt.f.slen]].to_numpy()
+ data = None
+
+ # Create a zeros array with the same number of rows as 'mat'
+ zeros_array = np.zeros((mat.shape[0], 5))
+
+ # Stack the zeros_array with the original 'mat'
+ m = np.column_stack([mat, zeros_array])
+ zeros_array = None
+
+ dtype = np.dtype(
+ [
+ ("source", "int"),
+ ("subject", "int"),
+ ("var", "float"),
+ ("slen", "int"),
+ ("s_W", "float"),
+ ("prob", "float"),
+ ("iter", "int"),
+ ("n_aln", "int"),
+ ("max_prob", "float"),
+ ]
+ )
+
+ # Convert the unstructured array to structured array
+ m = rf.unstructured_to_structured(m, dtype)
+ gc.collect()
+ log.info("::: Initializing data structures...")
+ init_data = initialize_subject_weights(m)
+ if reassign_iters > 0:
+ log.info(f"::: Reassigning reads with {reassign_iters} iterations")
+ else:
+ log.info("::: Reassigning reads until convergence")
+ no_multimaps = resolve_multimaps(
+ init_data, iters=reassign_iters, scale=reassign_scale
+ )
+
+ n_reads = len(list(set(no_multimaps["source"])))
+ n_refs = len(list(set(no_multimaps["subject"])))
+ n_alns = no_multimaps.shape[0]
+ log.info(
+ f"::: References: {n_refs:,} | Reads: {n_reads:,} | Alignments: {n_alns:,}"
+ )
+ log.info(
+ f'::: Unique mapping reads: {no_multimaps[no_multimaps["n_aln"] == 1].shape[0]:,} | Multimapping reads: {no_multimaps[no_multimaps["n_aln"] > 1].shape[0]:,}'
+ )
+
+ # add this to the array
+ # no_multimaps["n_aln"] = subject_counts_array[no_multimaps["subject"]]
+
+ # log.info(f"::: Removing references with less than {min_read_count} reads...")
+ # no_multimaps = no_multimaps[no_multimaps["n_aln"] >= min_read_count]
+ # log.info(f"{no_multimaps.shape[0]:,} alignments left")
+ log.info("::: Mapping back indices...")
+ if threads > 1:
+ dt.options.nthreads = threads - 1
+ else:
+ dt.options.nthreads = 1
+
+ g = dt.Frame(no_multimaps["source"])
+ g.names = ["qidx"]
+ reads.key = "qidx"
+ q = g[:, :, dt.join(reads)]
+
+ g = dt.Frame(no_multimaps["subject"])
+ g.names = ["sidx"]
+ refs.key = "sidx"
+ s = g[:, :, dt.join(refs)]
+
+ log.info("::: Calculating reads per subject...")
+ # count how many alignments are in each subjectId
+ s_c = s[:, dt.count(dt.f.subjectId), dt.by(dt.f.subjectId)]
+ s_c.names = ["subjectId", "counts"]
+ references_m = dict()
+ log.info(f"::: Removing references with less than {min_read_count:,}...")
+ for i, k in zip(s_c[:, "subjectId"].to_list()[0], s_c[:, "counts"].to_list()[0]):
+ if k >= min_read_count:
+ references_m[i] = k
+ log.info(f"::: ::: Keeping {len(references_m):,} references")
+ s_c = None
+ # convert columns queryId from q and subjectId from s to a tuple
+ log.info("::: Creating filtered set...")
+ entries = defaultdict(set)
+ q_query_ids = q[:, "queryId"].to_list()[0]
+ s_subject_ids = s[:, "subjectId"].to_list()[0]
+
+ for query_id, subject_id in zip(q_query_ids, s_subject_ids):
+ if subject_id in references_m:
+ entries[subject_id].add((query_id, subject_id))
+ no_multimaps = None
+ q = None
+ s = None
+ q_query_ids = None
+ s_subject_ids = None
+ gc.collect()
+ log.info("::: Writing to BAM file...")
+ write_reassigned_bam(
+ bam=bam,
+ ref_counts=references_m,
+ out_files=out_files,
+ threads=threads,
+ entries=entries,
+ sort_memory=sort_memory,
+        sort_by_name=sort_by_name,
+        min_read_ani=min_read_ani,
+ )
+
+
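+# Entry point for the "reassign" subcommand: sets up logging, the temporary directory
+# and the output paths, then resolves multi-mapping reads with reassign_reads.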
+def reassign(args):
+ logging.basicConfig(
+ level=logging.DEBUG,
+ format="%(levelname)s ::: %(asctime)s ::: %(message)s",
+ datefmt="%H:%M:%S",
+ )
+
+ bam = args.bam
+
+ tmp_dir = check_tmp_dir_exists(args.tmp_dir)
+ log.info("Temporary directory: %s", tmp_dir.name)
+
+ out_files = create_output_files(
+ prefix=args.prefix,
+ bam=args.bam,
+ tmp_dir=tmp_dir,
+ # bam_reassigned=args.bam_reassigned,
+ )
+
+ logging.getLogger("my_logger").setLevel(
+ logging.DEBUG if args.debug else logging.INFO
+ )
+
+ if args.debug:
+ warnings.showwarning = handle_warning
+ else:
+ warnings.filterwarnings("ignore")
+ logging.info("Resolving multi-mapping reads...")
+ reassign_reads(
+ bam=bam,
+ threads=args.threads,
+ reference_lengths=args.reference_lengths,
+ min_read_count=args.min_read_count,
+ min_read_ani=args.min_read_ani,
+ min_read_length=args.min_read_length,
+ reassign_iters=args.reassign_iters,
+ reassign_scale=args.reassign_scale,
+ sort_memory=args.sort_memory,
+ sort_by_name=args.sort_by_name,
+ out_files=out_files,
+ )
+ log.info("Done!")
diff --git a/bam_filter/sam_utils.py b/bam_filter/sam_utils.py
index 222d708..175ecc3 100644
--- a/bam_filter/sam_utils.py
+++ b/bam_filter/sam_utils.py
@@ -3,7 +3,7 @@
import os
import sys
import pandas as pd
-from multiprocessing import Pool
+from multiprocessing import Pool, Manager, cpu_count
import functools
from scipy import stats
import tqdm
@@ -15,11 +15,15 @@
initializer,
create_empty_output_files,
create_empty_bam,
+ sort_keys_by_approx_weight,
)
+import random
from bam_filter.entropy import entropy, norm_entropy, gini_coeff, norm_gini_coeff
from collections import defaultdict
import pyranges as pr
from pathlib import Path
+import concurrent.futures
+import gc
# import cProfile as profile
# import pstats
@@ -249,6 +253,9 @@ def get_bam_stats(
# prof.enable()
bam, references = params
results = []
+
+ if threads > 4:
+ threads = 4
samfile = pysam.AlignmentFile(bam, "rb", threads=threads)
read_hits = defaultdict(int)
@@ -339,8 +346,8 @@ def get_bam_stats(
mean_coverage_trunc, mean_coverage_trunc_len = get_tad(
cov_np,
- trim_min=10,
- trim_max=90,
+ trim_min=trim_min,
+ trim_max=trim_max,
)
cov_pos = cov_np[cov_np > 0]
@@ -806,7 +813,6 @@ def process_bam(
scale=1e6,
plot=False,
plots_dir="coverage-plots",
- chunksize=None,
read_length_freqs=False,
output_files=None,
low_memory=False,
@@ -839,14 +845,10 @@ def process_bam(
"The BAM file contains references not found in the reference lengths file"
)
sys.exit(1)
- sys.exit(1)
total_refs = samfile.nreferences
logging.info(f"Found {total_refs:,} reference sequences")
# logging.info(f"Found {samfile.mapped:,} alignments")
- logging.info(
- f"Removing references without mappings or less than {min_read_count} reads..."
- )
# Remove references without mapped reads
# alns_in_ref = defaultdict(int)
# for aln in tqdm.tqdm(
@@ -862,11 +864,28 @@ def process_bam(
if not os.path.exists(plots_dir):
os.makedirs(plots_dir)
- references = [
- chrom.contig
- for chrom in samfile.get_index_statistics()
- if chrom.mapped >= min_read_count
- ]
+ # references = [
+ # chrom.contig
+ # for chrom in samfile.get_index_statistics()
+ # if chrom.mapped >= min_read_count
+ # ]
+
+ logging.info(f"Removing references with less than {min_read_count} reads...")
+
+ references_m = {
+ chrom.contig: chrom.mapped
+ for chrom in samfile.get_index_statistics()
+ if chrom.mapped >= min_read_count
+ }
+ else:
+ logging.info(f"Removing references with less than {min_read_count} reads...")
+
+ references_m = {
+ chrom.contig: chrom.mapped
+ for chrom in samfile.get_index_statistics()
+ if chrom.mapped >= min_read_count
+ }
+ references = list(references_m.keys())
if len(references) == 0:
logging.warning("No reference sequences with alignments found in the BAM file")
@@ -887,18 +906,17 @@ def process_bam(
)
samfile = pysam.AlignmentFile(bam, "rb", threads=threads)
- if (chunksize is not None) and ((len(references) // chunksize) > threads):
- c_size = chunksize
- else:
- c_size = calc_chunksize(
- n_workers=threads, len_iterable=len(references), factor=4
- )
- ref_chunks = [references[i : i + c_size] for i in range(0, len(references), c_size)]
+ log.info("::: Creating reference chunks with uniform read amounts...")
+
+    # split the references into chunks that hold a similar number of mapped reads
+ ref_chunks = sort_keys_by_approx_weight(
+ input_dict=references_m, scale=1.5, num_cores=threads
+ )
+ log.info(f"::: Created {len(ref_chunks):,} chunks")
+ ref_chunks = random.sample(ref_chunks, len(ref_chunks))
params = zip([bam] * len(ref_chunks), ref_chunks)
try:
- logging.info(
- f"Processing {len(ref_chunks):,} chunks of {c_size:,} references each"
- )
+ logging.info(f"Processing {len(ref_chunks):,} chunks")
if is_debug():
data = list(
map(
@@ -921,8 +939,8 @@ def process_bam(
else:
p = Pool(
threads,
- initializer=initializer,
- initargs=([params, ref_lengths, scale],),
+ # initializer=initializer,
+ # initargs=([params, shared_dict],),
)
data = list(
@@ -962,6 +980,26 @@ def process_bam(
return data
+def write_to_file(alns, out_bam_file, header=None):
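+    # rebuild each pysam.AlignedSegment from its SAM-string form; the header is needed to map reference names back to numeric ids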
+ for aln in alns:
+ out_bam_file.write(pysam.AlignedSegment.fromstring(aln, header))
+
+
+def process_references_batch(references, bam, refs_idx, min_read_ani, threads=1):
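+    # alignments are collected as SAM strings so they can be returned from the worker process and rebuilt by write_to_file in the parent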
+ alns = []
+ with pysam.AlignmentFile(bam, "rb", threads=threads) as samfile:
+ for reference in references:
+ for aln in samfile.fetch(
+ reference=reference, multiple_iterators=False, until_eof=True
+ ):
+ ani_read = (1 - ((aln.get_tag("NM") / aln.infer_query_length()))) * 100
+ if ani_read >= min_read_ani:
+ aln.reference_id = refs_idx[aln.reference_name]
+ alns.append(aln.to_string())
+
+ return alns
+
+
def filter_reference_BAM(
bam,
df,
@@ -1060,15 +1098,27 @@ def filter_reference_BAM(
# prof.enable()
logging.info("Saving filtered stats...")
df_filtered.to_csv(
- out_files["stats_filtered"], sep="\t", index=False, compression="gzip"
+ out_files["stats_filtered"],
+ sep="\t",
+ index=False,
)
if out_files["bam_filtered"] is not None:
logging.info("Writing filtered BAM file... (be patient)")
- refs_dict = dict(
- zip(df_filtered["reference"], df_filtered["bam_reference_length"])
- )
- (ref_names, ref_lengths) = zip(*refs_dict.items())
+ # refs_dict = dict(
+ # zip(df_filtered["reference"], df_filtered["bam_reference_length"])
+ # )
+ references = df_filtered["reference"].tolist()
+ samfile = pysam.AlignmentFile(bam, "rb", threads=threads)
+ refs_dict = {x: samfile.get_reference_length(x) for x in references}
+ header = samfile.header
+ (ref_names, ref_lengths) = zip(*refs_dict.items())
+ ref_dict_m = {
+ chrom.contig: chrom.mapped
+ for chrom in samfile.get_index_statistics()
+ if chrom.contig in ref_names
+ }
+ samfile.close()
refs_idx = {sys.intern(str(x)): i for i, x in enumerate(ref_names)}
if threads > 4:
write_threads = 4
@@ -1083,28 +1133,84 @@ def filter_reference_BAM(
threads=write_threads,
)
- samfile = pysam.AlignmentFile(bam, "rb", threads=threads)
- references = [x for x in samfile.references if x in refs_idx.keys()]
-
- logging.info(
- f"::: Filtering {len(references):,} references sequentially..."
+ # logging.info(
+ # f"::: Filtering {len(references):,} references sequentially..."
+ # )
+ # for reference in tqdm.tqdm(
+ # references,
+ # total=len(references),
+ # leave=False,
+ # ncols=80,
+ # desc="References processed",
+ # ):
+ # for aln in samfile.fetch(
+ # reference=reference, multiple_iterators=False, until_eof=True
+ # ):
+ # ani_read = (
+ # 1 - ((aln.get_tag("NM") / aln.infer_query_length()))
+ # ) * 100
+ # if ani_read >= min_read_ani:
+ # aln.reference_id = refs_idx[aln.reference_name]
+ # out_bam_file.write(aln)
+ num_cores = min(threads, cpu_count())
+ # batch_size = len(references) // num_cores + 1 # Ensure non-zero batch size
+ log.info("::: Creating reference chunks with uniform read amounts...")
+ ref_chunks = sort_keys_by_approx_weight(
+ input_dict=ref_dict_m, scale=1.5, num_cores=threads, verbose=False
)
- for reference in tqdm.tqdm(
- references,
- total=len(references),
- leave=False,
- ncols=80,
- desc="References processed",
- ):
- for aln in samfile.fetch(
- reference=reference, multiple_iterators=False, until_eof=True
+ num_cores = min(num_cores, len(ref_chunks))
+ log.info(f"::: Using {num_cores} cores to write {len(ref_chunks)} chunk(s)")
+
+ # with Manager() as manager:
+ # # Use Manager to create a read-only proxy for the dictionary
+ # #refs_idx = manager.dict(refs_idx)
+
+ with concurrent.futures.ProcessPoolExecutor(
+ max_workers=num_cores
+ ) as executor:
+ # Use ProcessPoolExecutor to parallelize the processing of references in batches
+ futures = []
+ for batch_references in tqdm.tqdm(
+ ref_chunks,
+ total=len(ref_chunks),
+ desc="Submitted batches",
+ unit="batch",
+ ncols=80,
+ leave=False,
+ disable=is_debug(),
):
- ani_read = (
- 1 - ((aln.get_tag("NM") / aln.infer_query_length()))
- ) * 100
- if ani_read >= min_read_ani:
- aln.reference_id = refs_idx[aln.reference_name]
- out_bam_file.write(aln)
+ future = executor.submit(
+ process_references_batch,
+ batch_references,
+ bam,
+ refs_idx,
+ min_read_ani=min_read_ani,
+ )
+ futures.append(future) # Store the future
+
+ # Use a while loop to continuously check for completed futures
+ log.info("::: Collecting batches...")
+
+ completion_progress_bar = tqdm.tqdm(
+ total=len(futures),
+ desc="Completed",
+ unit="batch",
+ disable=is_debug(),
+ ncols=80,
+ leave=False,
+ )
+ completed_count = 0
+
+ # Use as_completed to iterate over completed futures as they become available
+ for completed_future in concurrent.futures.as_completed(futures):
+ alns = completed_future.result()
+ write_to_file(alns=alns, out_bam_file=out_bam_file, header=header)
+ # Update the progress bar for each completed write
+ completion_progress_bar.update(1)
+ completed_count += 1
+                completed_future.cancel()  # no-op for an already-finished future
+ gc.collect() # Force garbage collection
+ completion_progress_bar.close()
out_bam_file.close()
# prof.disable()
# # print profiling output
@@ -1124,6 +1230,7 @@ def filter_reference_BAM(
out_files["bam_filtered_tmp"],
)
else:
+ logging.info("Sorting BAM file by coordinates...")
pysam.sort(
"-@",
str(threads),
diff --git a/bam_filter/utils.py b/bam_filter/utils.py
index ea7cde7..397cde9 100644
--- a/bam_filter/utils.py
+++ b/bam_filter/utils.py
@@ -16,20 +16,153 @@
import numpy as np
from pathlib import Path
import pysam
+import math
+import tempfile
+
log = logging.getLogger("my_logger")
log.setLevel(logging.INFO)
timestr = time.strftime("%Y%m%d-%H%M%S")
+def handle_warning(message, category, filename, lineno, file=None, line=None):
+ print("A warning occurred:")
+ print(message)
+ print("Do you wish to continue?")
+
+ while True:
+ response = input("y/n: ").lower()
+ if response not in {"y", "n"}:
+ print("Not understood.")
+ else:
+ break
+
+ if response == "n":
+ raise category(message)
+
+
+# Ensure a temporary directory is available: create one in the current working directory when none is given, otherwise use the provided path (which must already exist)
+def check_tmp_dir_exists(tmpdir):
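+    # returns a tempfile.TemporaryDirectory object; callers use its .name and the directory is removed automatically when the object is cleaned up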
+ if tmpdir is None:
+ tmpdir = tempfile.TemporaryDirectory(dir=os.getcwd())
+ else:
+ if not os.path.exists(tmpdir):
+ log.error(f"Temporary directory {tmpdir} does not exist")
+ exit(1)
+ tmpdir = tempfile.TemporaryDirectory(dir=os.path.abspath(tmpdir))
+ return tmpdir
+
+
def is_debug():
return logging.getLogger("my_logger").getEffectiveLevel() == logging.DEBUG
+def sort_keys_by_approx_weight(
+ input_dict, scale=1, num_cores=1, refinement_steps=10, verbose=False
+):
+ # Check for division by zero
+ if scale == 0:
+ raise ValueError("Scale cannot be zero.")
+
+ # Calculate the target weight for each chunk
+ target_weight = int(scale * max(input_dict.values()))
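+    # e.g. (illustrative numbers) scale=1.5 and a heaviest key of 1,000 reads gives a target of ~1,500 reads per chunk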
+
+ # Sort keys by their weights in descending order
+ sorted_keys = sorted(input_dict, key=lambda k: input_dict[k], reverse=True)
+
+ # Calculate total weight
+ total_weight = sum(input_dict.values())
+
+ # Calculate the number of chunks based on the target weight
+ num_chunks = max(1, math.ceil(total_weight / target_weight))
+
+ # Initialize chunks with their total weights
+ chunks = [[] for _ in range(num_chunks)]
+ total_weights = [0] * num_chunks
+
+ # Create a progress bar
+ progress_bar = tqdm.tqdm(
+ total=len(sorted_keys),
+ desc="Distributing keys",
+ unit="k",
+ unit_scale=True,
+ unit_divisor=1000,
+        disable=is_debug(),
+ leave=False,
+ ncols=80,
+ )
+
+ # Distribute keys into chunks
+ for key in sorted_keys:
+ min_chunk = min(range(num_chunks), key=lambda i: total_weights[i])
+ chunks[min_chunk].append(key)
+ total_weights[min_chunk] += input_dict[key]
+
+ # Update the progress bar
+ progress_bar.update(1)
+
+ # Close the progress bar
+ progress_bar.close()
+
+ # Remove empty chunks
+ chunks = [chunk for chunk in chunks if chunk]
+
+    # Even out the number of keys per chunk so no worker receives a much longer
+    # reference list than the others; total_weights is updated for every moved
+    # key so the refinement below operates on accurate totals
+    max_keys = max(len(chunk) for chunk in chunks)
+    for idx, chunk in enumerate(chunks):
+        while len(chunk) < max_keys:
+            # Borrow a key from the first other non-empty chunk
+            other_idx = next(
+                (i for i, c in enumerate(chunks) if i != idx and c), None
+            )
+            if other_idx is None:
+                break
+            key = chunks[other_idx].pop(0)
+            chunk.append(key)
+            total_weights[idx] += input_dict[key]
+            total_weights[other_idx] -= input_dict[key]
+
+ # Initial balance
+ initial_balance = max(total_weights) - min(total_weights)
+
+ # Refinement step
+ for _ in range(refinement_steps):
+ # Sort chunks by their total weights in ascending order
+ sorted_chunks = sorted(range(num_chunks), key=lambda i: total_weights[i])
+
+ # Move keys between chunks to minimize the difference in total weights
+ for src_chunk in sorted_chunks[:-1]:
+ dest_chunk = sorted_chunks[-1]
+ if total_weights[dest_chunk] - total_weights[src_chunk] > 1:
+ # Move one key from dest_chunk to src_chunk
+ key_to_move = chunks[dest_chunk].pop(0)
+ chunks[src_chunk].append(key_to_move)
+ total_weights[src_chunk] += input_dict[key_to_move]
+ total_weights[dest_chunk] -= input_dict[key_to_move]
+
+ # Check for improvement in balance
+ current_balance = max(total_weights) - min(total_weights)
+ if current_balance >= initial_balance:
+ break # No improvement, exit the loop
+
+ # Update initial balance for the next iteration
+ initial_balance = current_balance
+
+ chunks = [chunk for chunk in chunks if chunk]
+
+ # Print the min, max, and average weight of each chunk
+ if verbose:
+ for i, chunk in enumerate(chunks, 1):
+ chunk_weights = [input_dict[key] for key in chunk]
+ min_weight = min(chunk_weights)
+ max_weight = max(chunk_weights)
+ avg_weight = sum(chunk_weights) / len(chunk_weights)
+ print(
+ f"Chunk {i}: Total = {sum(chunk_weights)}, Min Weight = {min_weight}, Max Weight = {max_weight}, Average Weight = {avg_weight}"
+ )
+
+ return chunks
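+
+
+# Minimal usage sketch for sort_keys_by_approx_weight (illustrative only; the
+# counts below are made up and nothing in the package runs this snippet):
+#
+#   counts = {"refA": 1200, "refB": 950, "refC": 80, "refD": 40, "refE": 15}
+#   chunks = sort_keys_by_approx_weight(counts, scale=1.5, num_cores=2)
+#   for chunk in chunks:
+#       print(chunk, sum(counts[k] for k in chunk))
+#
+# Each chunk aims to carry a roughly comparable total read count, so the BAM
+# work handed to each worker in the pool is balanced.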
+
+
def create_empty_output_files(out_files):
for key, value in out_files.items():
if value is not None:
- if key == "bam_filtered":
+ if key == "bam_filtered" or key == "bam_reassigned":
create_empty_bam(value)
elif (
key == "bam_filtered_tmp" or key == "bam_tmp" or key == "bam_tmp_sorted"
@@ -162,6 +295,78 @@ def is_valid_file(parser, arg, var):
return arg
+# From https://stackoverflow.com/a/59617044/15704171
+def convert_list_to_str(lst):
+ n = len(lst)
+ if not n:
+ return ""
+ if n == 1:
+ return lst[0]
+ return ", ".join(lst[:-1]) + f" or {lst[-1]}"
+
+
+lca_ranks = [
+ "superkingdom",
+ "domain",
+ "lineage",
+ "kingdom",
+ "subkingdom",
+ "superphylum",
+ "phylum",
+ "subphylum",
+ "superclass",
+ "class",
+ "subclass",
+ "infraclass",
+ "clade",
+ "cohort",
+ "subcohort",
+ "superorder",
+ "order",
+ "suborder",
+ "infraorder",
+ "parvorder",
+ "superfamily",
+ "family",
+ "subfamily",
+ "tribe",
+ "subtribe",
+ "infratribe",
+ "genus",
+ "subgenus",
+ "section",
+ "series",
+ "subseries",
+ "subsection",
+ "species",
+ "species group",
+ "species subgroup",
+ "subspecies",
+ "varietas",
+ "morph",
+ "subvariety",
+ "forma",
+ "forma specialis",
+ "biotype",
+ "genotype",
+ "isolate",
+ "pathogroup",
+ "serogroup",
+ "serotype",
+ "strain",
+]
+
+
+def check_lca_ranks(val, parser, var):
+ value = str(val)
+ if value in lca_ranks:
+ return value
+ else:
+ parser.error(
+ f"argument {var}: Invalid value {value}. Filter has to be one of {convert_list_to_str(lca_ranks)}"
+ )
+
+
defaults = {
"min_read_length": 30,
"min_read_count": 3,
@@ -183,10 +388,16 @@ def is_valid_file(parser, arg, var):
"stats": None,
"stats_filtered": None,
"bam_filtered": None,
+ "bam_reassigned": None,
"knee_plot": None,
"read_length_freqs": None,
"read_hits_count": None,
"tmp_dir": None,
+ "reassign_iters": 25,
+ "reassign_scale": 0.9,
+ "reassign_target": "mapq",
+ "rank_lca": "species",
+ "lca_summary": None,
}
help_msg = {
@@ -215,6 +426,7 @@ def is_valid_file(parser, arg, var):
"stats": "Save a TSV file with the statistics for each reference",
"stats_filtered": "Save a TSV file with the statistics for each reference after filtering",
"bam_filtered": "Save a BAM file with the references that passed the filtering criteria",
+ "bam_reassigned": "Save a BAM file without multimapping reads",
"coverage_plots": "Folder where to save genome coverage plots",
"knee_plot": "Plot knee plot",
"sort_by_name": "Sort by read names",
@@ -225,6 +437,17 @@ def is_valid_file(parser, arg, var):
"debug": "Print debug messages",
"reference_lengths": "File with references lengths",
"low_memory": "Activate the low memory mode",
+ "reassign": "Run an EM algorithm to reassign reads to references",
+ "reassign_method": "Method for the EM algorithm",
+ "reassign_iters": "Number of iterations for the EM algorithm",
+ "reassign_scale": "Scale to select the best weithing alignments",
+ "reassign_target": "Which target to use for the EM algorith, Only mapq or pident.",
+ "lca": "Calculate LCA for each read and estimate abundances",
+ "names": "Names dmp file from taxonomy",
+ "nodes": "Nodes dmp file from taxonomy",
+ "acc2taxid": "acc2taxid file from taxonomy",
+ "rank_lca": "Rank to use for LCA calculation",
+ "lca_summary": "Save a TSV file with the LCA summary",
"version": "Print program version",
}
@@ -234,18 +457,53 @@ def get_arguments(argv=None):
description="A simple tool to calculate metrics from a BAM file and filter with uneven coverage.",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
- # add subparser for filtering options:
- filter_args = parser.add_argument_group("filtering arguments")
- misc_args = parser.add_argument_group("miscellaneous arguments")
- out_args = parser.add_argument_group("output arguments")
parser.add_argument(
+ "--version",
+ action="version",
+ version="%(prog)s " + __version__,
+ help=help_msg["version"],
+ )
+ parser.add_argument(
+ "--debug", dest="debug", action="store_true", help=help_msg["debug"]
+ )
+
+ sub_parsers = parser.add_subparsers(
+ help="positional arguments",
+ dest="action",
+ )
+
+    # Create the parent parser shared by all subcommands. Note `add_help=False` and creation via `argparse.ArgumentParser()`.
+ parent_parser = argparse.ArgumentParser(add_help=False)
+
+ required = parent_parser.add_argument_group("required arguments")
+ required.add_argument(
"--bam",
required=True,
dest="bam",
type=lambda x: is_valid_file(parser, x, "bam"),
help=help_msg["bam"],
)
- parser.add_argument(
+ # required = parent_parser.add_argument_group("required arguments")
+ optional = parent_parser.add_argument_group("optional arguments")
+ optional.add_argument(
+ "-p",
+ "--prefix",
+ type=str,
+ default=defaults["prefix"],
+ metavar="STR",
+ dest="prefix",
+ help=help_msg["prefix"],
+ )
+ optional.add_argument(
+ "-r",
+ "--reference-lengths",
+ type=lambda x: is_valid_file(parser, x, "reference_lengths"),
+ metavar="FILE",
+ default=defaults["reference_lengths"],
+ dest="reference_lengths",
+ help=help_msg["reference_lengths"],
+ )
+ optional.add_argument(
"-t",
"--threads",
type=lambda x: int(
@@ -256,7 +514,155 @@ def get_arguments(argv=None):
default=1,
help=help_msg["threads"],
)
- misc_args.add_argument(
+    # Create the parser sub-command for read reassignment
+ parser_reassign = sub_parsers.add_parser(
+ "reassign",
+ help="Reassign reads to references using an EM algorithm",
+ parents=[parent_parser],
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+ )
+ # create the parser sub-commands
+ parser_filter = sub_parsers.add_parser(
+ "filter",
+ help="Filter references based on coverage and other metrics",
+ parents=[parent_parser],
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+ )
+ parser_lca = sub_parsers.add_parser(
+ "lca",
+ help="Calculate LCA for each read and estimate abundances at each rank",
+ parents=[parent_parser],
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+ )
+
+ # createdb_required_args = parser_createdb.add_argument_group("required arguments")
+ # reassign_required_args = parser_reassign.add_argument_group(
+ # "Re-assign required arguments"
+ # )
+ reassign_optional_args = parser_reassign.add_argument_group(
+ "Re-assign optional arguments"
+ )
+
+ filter_required_args = parser_filter.add_argument_group("Filter required arguments")
+ # filter_optional_args = parser_filter.add_argument_group("Filter optional arguments")
+
+ # lca_required_args = parser_lca.add_argument_group("LCA required arguments")
+ lca_optional_args = parser_lca.add_argument_group("LCA optional arguments")
+
+ # add subparser for filtering options:
+ # reassign_args = parser.add_argument_group("reassign arguments")
+ filtering_filt_args = parser_filter.add_argument_group("filtering arguments")
+ # lca_args = parser.add_argument_group("lca arguments")
+ misc_filter_args = parser_filter.add_argument_group("miscellaneous arguments")
+ out_filter_args = parser_filter.add_argument_group("output arguments")
+ # parser.add_argument(
+ # "--bam",
+ # required=True,
+ # dest="bam",
+ # type=lambda x: is_valid_file(parser, x, "bam"),
+ # help=help_msg["bam"],
+ # )
+ # parser.add_argument(
+ # "-t",
+ # "--threads",
+ # type=lambda x: int(
+ # check_values(x, minval=1, maxval=1000, parser=parser, var="--threads")
+ # ),
+ # dest="threads",
+ # metavar="INT",
+ # default=1,
+ # help=help_msg["threads"],
+ # )
+
+ reassign_optional_args.add_argument(
+ "-i",
+ "--iters",
+ type=lambda x: int(
+ check_values(
+ x, minval=0, maxval=100000, parser=parser, var="--reassign-n-iters"
+ )
+ ),
+ metavar="INT",
+ default=defaults["reassign_iters"],
+ dest="reassign_iters",
+ help=help_msg["reassign_iters"],
+ )
+ reassign_optional_args.add_argument(
+ "-s",
+ "--scale",
+ type=lambda x: float(
+ check_values(x, minval=0, maxval=1, parser=parser, var="--scale")
+ ),
+ metavar="FLOAT",
+ default=defaults["reassign_scale"],
+ dest="reassign_scale",
+ help=help_msg["reassign_scale"],
+ )
+ reassign_optional_args.add_argument(
+ "-A",
+ "--min-read-ani",
+ type=lambda x: float(
+ check_values(x, minval=0, maxval=100, parser=parser, var="--min-read-ani")
+ ),
+ metavar="FLOAT",
+ default=defaults["min_read_ani"],
+ dest="min_read_ani",
+ help=help_msg["min_read_ani"],
+ )
+ reassign_optional_args.add_argument(
+ "-l",
+ "--min-read-length",
+ type=lambda x: int(
+ check_values(
+ x, minval=1, maxval=100000, parser=parser, var="--min-read-length"
+ )
+ ),
+ default=defaults["min_read_length"],
+ metavar="INT",
+ dest="min_read_length",
+ help=help_msg["min_read_length"],
+ )
+ reassign_optional_args.add_argument(
+ "-n",
+ "--min-read-count",
+ type=lambda x: int(
+ check_values(
+ x, minval=1, maxval=np.Inf, parser=parser, var="--min-read-count"
+ )
+ ),
+ default=defaults["min_read_count"],
+ metavar="INT",
+ dest="min_read_count",
+ help=help_msg["min_read_count"],
+ )
+ reassign_optional_args.add_argument(
+ "-o",
+ "--out-bam",
+ dest="bam_reassigned",
+ default=defaults["bam_reassigned"],
+ metavar="FILE",
+ type=str,
+ nargs="?",
+ const="",
+ help=help_msg["bam_reassigned"],
+ )
+ reassign_optional_args.add_argument(
+ "-m",
+ "--sort-memory",
+ type=lambda x: check_suffix(x, parser=parser, var="--sort-memory"),
+ default=defaults["sort_memory"],
+ metavar="STR",
+ dest="sort_memory",
+ help=help_msg["sort_memory"],
+ )
+ reassign_optional_args.add_argument(
+ "-N",
+ "--sort-by-name",
+ dest="sort_by_name",
+ action="store_true",
+ help=help_msg["sort_by_name"],
+ )
+ misc_filter_args.add_argument(
"--reference-trim-length",
type=lambda x: int(
check_values(
@@ -268,7 +674,7 @@ def get_arguments(argv=None):
default=0,
help=help_msg["trim_ends"],
)
- misc_args.add_argument(
+ misc_filter_args.add_argument(
"--trim-min",
type=lambda x: int(
check_values(x, minval=0, maxval=100, parser=parser, var="--trim-min")
@@ -278,7 +684,7 @@ def get_arguments(argv=None):
default=10,
help=help_msg["trim_min"],
)
- misc_args.add_argument(
+ misc_filter_args.add_argument(
"--trim-max",
type=lambda x: int(
check_values(x, minval=0, maxval=100, parser=parser, var="--trim-max")
@@ -288,16 +694,7 @@ def get_arguments(argv=None):
default=90,
help=help_msg["trim_max"],
)
- parser.add_argument(
- "-p",
- "--prefix",
- type=str,
- default=defaults["prefix"],
- metavar="STR",
- dest="prefix",
- help=help_msg["prefix"],
- )
- filter_args.add_argument(
+ filtering_filt_args.add_argument(
"-A",
"--min-read-ani",
type=lambda x: float(
@@ -308,7 +705,7 @@ def get_arguments(argv=None):
dest="min_read_ani",
help=help_msg["min_read_ani"],
)
- filter_args.add_argument(
+ filtering_filt_args.add_argument(
"-l",
"--min-read-length",
type=lambda x: int(
@@ -321,7 +718,7 @@ def get_arguments(argv=None):
dest="min_read_length",
help=help_msg["min_read_length"],
)
- filter_args.add_argument(
+ filtering_filt_args.add_argument(
"-n",
"--min-read-count",
type=lambda x: int(
@@ -334,7 +731,7 @@ def get_arguments(argv=None):
dest="min_read_count",
help=help_msg["min_read_count"],
)
- filter_args.add_argument(
+ filtering_filt_args.add_argument(
"-b",
"--min-expected-breadth-ratio",
type=lambda x: float(
@@ -347,7 +744,7 @@ def get_arguments(argv=None):
dest="min_expected_breadth_ratio",
help=help_msg["min_expected_breadth_ratio"],
)
- filter_args.add_argument(
+ filtering_filt_args.add_argument(
"-e",
"--min-normalized-entropy",
type=lambda x: check_values_auto(
@@ -358,7 +755,7 @@ def get_arguments(argv=None):
dest="min_norm_entropy",
help=help_msg["min_norm_entropy"],
)
- filter_args.add_argument(
+ filtering_filt_args.add_argument(
"-g",
"--min-normalized-gini",
type=lambda x: check_values_auto(
@@ -369,7 +766,7 @@ def get_arguments(argv=None):
dest="min_norm_gini",
help=help_msg["min_norm_gini"],
)
- filter_args.add_argument(
+ filtering_filt_args.add_argument(
"-B",
"--min-breadth",
type=lambda x: float(
@@ -380,7 +777,7 @@ def get_arguments(argv=None):
dest="min_breadth",
help=help_msg["min_breadth"],
)
- filter_args.add_argument(
+ filtering_filt_args.add_argument(
"-a",
"--min-avg-read-ani",
type=lambda x: float(
@@ -393,7 +790,7 @@ def get_arguments(argv=None):
dest="min_avg_read_ani",
help=help_msg["min_avg_read_ani"],
)
- filter_args.add_argument(
+ filtering_filt_args.add_argument(
"-c",
"--min-coverage-evenness",
type=lambda x: float(
@@ -406,7 +803,7 @@ def get_arguments(argv=None):
dest="min_coverage_evenness",
help=help_msg["min_coverage_evenness"],
)
- filter_args.add_argument(
+ filtering_filt_args.add_argument(
"-V",
"--min-coeff-var",
type=lambda x: float(
@@ -419,7 +816,7 @@ def get_arguments(argv=None):
dest="min_coeff_var",
help=help_msg["min_coeff_var"],
)
- filter_args.add_argument(
+ filtering_filt_args.add_argument(
"-C",
"--min-coverage-mean",
type=lambda x: float(
@@ -432,13 +829,13 @@ def get_arguments(argv=None):
dest="min_coverage_mean",
help=help_msg["min_coverage_mean"],
)
- filter_args.add_argument(
+ filtering_filt_args.add_argument(
"--include-low-detection",
dest="transform_cov_evenness",
action="store_true",
help=help_msg["transform_cov_evenness"],
)
- parser.add_argument(
+ misc_filter_args.add_argument(
"-m",
"--sort-memory",
type=lambda x: check_suffix(x, parser=parser, var="--sort-memory"),
@@ -447,20 +844,20 @@ def get_arguments(argv=None):
dest="sort_memory",
help=help_msg["sort_memory"],
)
- parser.add_argument(
+ misc_filter_args.add_argument(
"-N",
"--sort-by-name",
dest="sort_by_name",
action="store_true",
help=help_msg["sort_by_name"],
)
- parser.add_argument(
+ misc_filter_args.add_argument(
"--disable-sort",
dest="disable_sort",
action="store_true",
help=help_msg["disable_sort"],
)
- parser.add_argument(
+ misc_filter_args.add_argument(
"--scale",
type=lambda x: check_suffix(x, parser=parser, var="--scale"),
default=defaults["scale"],
@@ -469,16 +866,16 @@ def get_arguments(argv=None):
help=help_msg["scale"],
)
# reference_lengths
- parser.add_argument(
- "-r",
- "--reference-lengths",
- type=lambda x: is_valid_file(parser, x, "reference_lengths"),
- metavar="FILE",
- default=defaults["reference_lengths"],
- dest="reference_lengths",
- help=help_msg["reference_lengths"],
- )
- out_args.add_argument(
+ # filter_optional_args.add_argument(
+ # "-r",
+ # "--reference-lengths",
+ # type=lambda x: is_valid_file(parser, x, "reference_lengths"),
+ # metavar="FILE",
+ # default=defaults["reference_lengths"],
+ # dest="reference_lengths",
+ # help=help_msg["reference_lengths"],
+ # )
+ filter_required_args.add_argument(
"--stats",
dest="stats",
default=defaults["stats"],
@@ -489,7 +886,7 @@ def get_arguments(argv=None):
required=True,
help=help_msg["stats"],
)
- out_args.add_argument(
+ out_filter_args.add_argument(
"--stats-filtered",
dest="stats_filtered",
default=defaults["stats_filtered"],
@@ -499,7 +896,7 @@ def get_arguments(argv=None):
const="",
help=help_msg["stats_filtered"],
)
- out_args.add_argument(
+ out_filter_args.add_argument(
"--bam-filtered",
dest="bam_filtered",
default=defaults["bam_filtered"],
@@ -509,7 +906,7 @@ def get_arguments(argv=None):
const="",
help=help_msg["bam_filtered"],
)
- out_args.add_argument(
+ out_filter_args.add_argument(
"--read-length-freqs",
dest="read_length_freqs",
default=defaults["read_length_freqs"],
@@ -519,7 +916,7 @@ def get_arguments(argv=None):
const="",
help=help_msg["read_length_freqs"],
)
- out_args.add_argument(
+ out_filter_args.add_argument(
"--read-hits-count",
dest="read_hits_count",
default=defaults["read_hits_count"],
@@ -529,7 +926,7 @@ def get_arguments(argv=None):
const="",
help=help_msg["read_hits_count"],
)
- out_args.add_argument(
+ out_filter_args.add_argument(
"--knee-plot",
dest="knee_plot",
default=defaults["knee_plot"],
@@ -539,7 +936,7 @@ def get_arguments(argv=None):
const="",
help=help_msg["knee_plot"],
)
- out_args.add_argument(
+ out_filter_args.add_argument(
"--coverage-plots",
dest="coverage_plots",
metavar="FILE",
@@ -549,17 +946,17 @@ def get_arguments(argv=None):
const="",
help=help_msg["coverage_plots"],
)
- parser.add_argument(
- "--chunk-size",
- type=lambda x: int(
- check_values(x, minval=1, maxval=100000, parser=parser, var="--chunk-size")
- ),
- default=defaults["chunk_size"],
- metavar="INT",
- dest="chunk_size",
- help=help_msg["chunk_size"],
- )
- parser.add_argument(
+ # parser.add_argument(
+ # "--chunk-size",
+ # type=lambda x: int(
+ # check_values(x, minval=1, maxval=100000, parser=parser, var="--chunk-size")
+ # ),
+ # default=defaults["chunk_size"],
+ # metavar="INT",
+ # dest="chunk_size",
+ # help=help_msg["chunk_size"],
+ # )
+ misc_filter_args.add_argument(
"--tmp-dir",
type=str,
default=defaults["tmp_dir"],
@@ -567,20 +964,59 @@ def get_arguments(argv=None):
dest="tmp_dir",
help=help_msg["tmp_dir"],
)
- parser.add_argument(
+ misc_filter_args.add_argument(
"--low-memory",
dest="low_memory",
action="store_true",
help=help_msg["low_memory"],
)
- parser.add_argument(
- "--debug", dest="debug", action="store_true", help=help_msg["debug"]
+
+ lca_optional_args.add_argument(
+ "--names",
+ metavar="FILE",
+ type=lambda x: is_valid_file(parser, x, "names"),
+ dest="names",
+ help=help_msg["names"],
)
- parser.add_argument(
- "--version",
- action="version",
- version="%(prog)s " + __version__,
- help=help_msg["version"],
+ lca_optional_args.add_argument(
+ "--nodes",
+ metavar="FILE",
+ type=lambda x: is_valid_file(parser, x, "nodes"),
+ dest="nodes",
+ help=help_msg["nodes"],
+ )
+ lca_optional_args.add_argument(
+ "--acc2taxid",
+ metavar="FILE",
+ type=lambda x: is_valid_file(parser, x, "acc2taxid"),
+ dest="acc2taxid",
+ help=help_msg["acc2taxid"],
+ )
+ lca_optional_args.add_argument(
+ "--rank-lca",
+ metavar="STR",
+ type=lambda x: str(check_lca_ranks(x, parser=parser, var="--rank-lca")),
+ default=defaults["rank_lca"],
+ dest="rank_lca",
+ help=help_msg["rank_lca"],
+ )
+ lca_optional_args.add_argument(
+ "--lca-summary",
+ dest="lca_summary",
+ metavar="FILE",
+ default=defaults["lca_summary"],
+ type=str,
+ nargs="?",
+ const="",
+ help=help_msg["lca_summary"],
+ )
+ lca_optional_args.add_argument(
+ "--scale",
+ type=lambda x: check_suffix(x, parser=parser, var="--scale"),
+ default=defaults["scale"],
+ dest="scale",
+ metavar="STR",
+ help=help_msg["scale"],
)
args = parser.parse_args(None if sys.argv[1:] else ["-h"])
return args
@@ -711,46 +1147,59 @@ def calc_chunksize(n_workers, len_iterable, factor=4):
# }
# return out_files
def create_output_files(
- prefix,
bam,
- stats,
- stats_filtered,
- bam_filtered,
- read_length_freqs,
- read_hits_count,
- knee_plot,
- coverage_plots,
tmp_dir,
+ prefix=None,
+ stats="",
+ stats_filtered="",
+ bam_reassigned="",
+ bam_filtered="",
+ read_length_freqs="",
+ read_hits_count="",
+ knee_plot="",
+ coverage_plots="",
+ lca_summary="",
):
if prefix is None:
prefix = Path(bam).with_suffix("").name
- if stats == "":
+ if tmp_dir is not None:
+ tmp_dir = tmp_dir.name
+
+ if stats == "" or stats is None:
stats = f"{prefix}_stats.tsv.gz"
- if stats_filtered == "":
+ if stats_filtered == "" or stats_filtered is None:
stats_filtered = f"{prefix}_stats-filtered.tsv.gz"
- if bam_filtered == "":
+ if bam_filtered == "" or bam_filtered is None:
bam_filtered = f"{prefix}.filtered.bam"
- if read_length_freqs == "":
+ if bam_reassigned == "" or bam_reassigned is None:
+ bam_reassigned = f"{prefix}.reassigned.bam"
+ if read_length_freqs == "" or read_length_freqs is None:
read_length_freqs = f"{prefix}_read-length-freqs.json"
- if read_hits_count == "":
+ if read_hits_count == "" or read_hits_count is None:
read_hits_count = f"{prefix}_read-hits-count.tsv.gz"
- if knee_plot == "":
+ if knee_plot == "" or knee_plot is None:
knee_plot = f"{prefix}_knee-plot.png"
- if coverage_plots == "":
+ if coverage_plots == "" or coverage_plots is None:
coverage_plots = f"{prefix}_coverage-plots"
+ if lca_summary == "" or lca_summary is None:
+ lca_summary = f"{prefix}_lca-summary.tsv.gz"
# create output files
out_files = {
"stats": stats,
"stats_filtered": stats_filtered,
- "bam_filtered_tmp": f"{tmp_dir.name}/{prefix}.filtered.tmp.bam",
- "bam_tmp": f"{tmp_dir.name}/{prefix}.tmp.bam",
- "bam_tmp_sorted": f"{tmp_dir.name}/{prefix}.tmp.sorted.bam",
+ "bam_filtered_tmp": f"{tmp_dir}/{prefix}.filtered.tmp.bam",
+ "bam_tmp": f"{tmp_dir}/{prefix}.tmp.bam",
+ "bam_tmp_sorted": f"{tmp_dir}/{prefix}.tmp.sorted.bam",
"bam_filtered": bam_filtered,
+ "bam_reassigned_tmp": f"{tmp_dir}/{prefix}.reassigned.tmp.bam",
+ "bam_reassigned_sorted": f"{tmp_dir}/{prefix}.reassigned.sorted.bam",
+ "bam_reassigned": bam_reassigned,
"read_length_freqs": read_length_freqs,
"read_hits_count": read_hits_count,
"knee_plot": knee_plot,
"coverage_plot_dir": coverage_plots,
+ "lca_summary": lca_summary,
}
return out_files