diff --git a/bam_filter/__main__.py b/bam_filter/__main__.py index c5dc21e..cd5f986 100644 --- a/bam_filter/__main__.py +++ b/bam_filter/__main__.py @@ -1,3 +1,4 @@ +# """ This program is free software: you can redistribute it and/or modify it under the terms of the GNU General Public License as published by the Free Software Foundation, either version 3 of the @@ -11,271 +12,35 @@ see . """ - import logging -import pandas as pd -from bam_filter.sam_utils import process_bam, filter_reference_BAM, check_bam_file + +from bam_filter.reassign import reassign +from bam_filter.filter import filter_references +from bam_filter.lca import do_lca from bam_filter.utils import ( get_arguments, - create_output_files, - concat_df, ) -from bam_filter.entropy import find_knee -import json -import warnings -from collections import Counter -from functools import reduce -import os -import tempfile log = logging.getLogger("my_logger") -def handle_warning(message, category, filename, lineno, file=None, line=None): - print("A warning occurred:") - print(message) - print("Do you wish to continue?") - - while True: - response = input("y/n: ").lower() - if response not in {"y", "n"}: - print("Not understood.") - else: - break - - if response == "n": - raise category(message) - - -def obj_dict(obj): - return obj.__dict__ - - -def get_summary(obj): - return obj.to_summary() - - -def get_lens(obj): - return obj.get_read_length_freqs() - - -# Check if the temporary directory exists, if not, create it -def check_tmp_dir_exists(tmpdir): - if tmpdir is None: - tmpdir = tempfile.TemporaryDirectory(dir=os.getcwd()) - else: - if not os.path.exists(tmpdir): - log.error(f"Temporary directory {tmpdir} does not exist") - exit(1) - tmpdir = tempfile.TemporaryDirectory(dir=os.path.abspath(tmpdir)) - return tmpdir - - def main(): logging.basicConfig( level=logging.DEBUG, format="%(levelname)s ::: %(asctime)s ::: %(message)s", datefmt="%Y-%m-%d %H:%M:%S", ) - + logging.getLogger("urllib3").setLevel(logging.WARNING) args = get_arguments() - - tmp_dir = check_tmp_dir_exists(args.tmp_dir) - log.info("Temporary directory: %s", tmp_dir.name) - if args.trim_min >= args.trim_max: - log.error("trim_min must be less than trim_max") - exit(1) - logging.getLogger("my_logger").setLevel( logging.DEBUG if args.debug else logging.INFO ) - - if args.debug: - warnings.showwarning = handle_warning - else: - warnings.filterwarnings("ignore") - - plot_coverage = False - if args.coverage_plots is not None: - plot_coverage = True - - if args.stats_filtered is None and (args.bam_filtered is not None): - logging.error( - "You need to specify a filtereds stats file to obtain the filtered BAM file" - ) - exit(1) - - out_files = create_output_files( - prefix=args.prefix, - bam=args.bam, - stats=args.stats, - stats_filtered=args.stats_filtered, - bam_filtered=args.bam_filtered, - read_length_freqs=args.read_length_freqs, - read_hits_count=args.read_hits_count, - knee_plot=args.knee_plot, - coverage_plots=args.coverage_plots, - tmp_dir=tmp_dir, - ) - - bam = check_bam_file( - bam=args.bam, - threads=args.threads, - reference_lengths=args.reference_lengths, - sort_memory=args.sort_memory, - ) - - data = process_bam( - bam=bam, - threads=args.threads, - reference_lengths=args.reference_lengths, - min_read_count=args.min_read_count, - min_read_ani=args.min_read_ani, - trim_ends=args.trim_ends, - trim_min=args.trim_min, - trim_max=args.trim_max, - scale=args.scale, - plot=plot_coverage, - plots_dir=out_files["coverage_plot_dir"], - chunksize=args.chunk_size, - 
read_length_freqs=args.read_length_freqs, - output_files=out_files, - sort_memory=args.sort_memory, - low_memory=args.low_memory, - ) - if args.low_memory: - bam = out_files["bam_tmp_sorted"] - - logging.info("Reducing results to a single dataframe") - # data = list(filter(None, data)) - data_df = [x[0] for x in data if x[0] is not None] - data_df = concat_df(data_df) - - if args.read_length_freqs is not None: - logging.info("Calculating read length frequencies...") - lens = [x[1] for x in data if x[1] is not None] - lens = json.dumps(lens, default=obj_dict, ensure_ascii=False, indent=4) - with open(out_files["read_length_freqs"], "w", encoding="utf-8") as outfile: - print(lens, file=outfile) - - if args.read_hits_count is not None: - logging.info("Calculating read hits counts...") - hits = [x[2] for x in data if x[2] is not None] - - # merge dicts and sum values - hits = reduce(lambda x, y: x.update(y) or x, (Counter(dict(x)) for x in hits)) - # hits = sum(map(Counter, hits), Counter()) - - # convert dict to dataframe - hits = ( - pd.DataFrame.from_dict(hits, orient="index", columns=["count"]) - .rename_axis("read") - .reset_index() - .sort_values(by="count", ascending=False) - ) - - hits.to_csv( - out_files["read_hits_count"], sep="\t", index=False, compression="gzip" - ) - - logging.info(f"Writing reference statistics to {out_files['stats']}") - data_df.to_csv(out_files["stats"], sep="\t", index=False, compression="gzip") - - if args.min_norm_entropy is None or args.min_norm_gini is None: - filter_conditions = { - "min_read_length": args.min_read_length, - "min_read_count": args.min_read_count, - "min_expected_breadth_ratio": args.min_expected_breadth_ratio, - "min_breadth": args.min_breadth, - "min_avg_read_ani": args.min_avg_read_ani, - "min_coverage_evenness": args.min_coverage_evenness, - "min_coeff_var": args.min_coeff_var, - "min_coverage_mean": args.min_coverage_mean, - } - elif args.min_norm_entropy == "auto" or args.min_norm_gini == "auto": - if data_df.shape[0] > 1: - min_norm_gini, min_norm_entropy = find_knee( - data_df, out_plot_name=out_files["knee_plot"] - ) - - if min_norm_gini is None or min_norm_entropy is None: - logging.warning( - "Could not find knee in entropy plot. Disabling filtering by entropy/gini inequality." - ) - filter_conditions = { - "min_read_length": args.min_read_length, - "min_read_count": args.min_read_count, - "min_expected_breadth_ratio": args.min_expected_breadth_ratio, - "min_breadth": args.min_breadth, - "min_avg_read_ani": args.min_avg_read_ani, - "min_coverage_evenness": args.min_coverage_evenness, - "min_coeff_var": args.min_coeff_var, - "min_coverage_mean": args.min_coverage_mean, - } - else: - filter_conditions = { - "min_read_length": args.min_read_length, - "min_read_count": args.min_read_count, - "min_expected_breadth_ratio": args.min_expected_breadth_ratio, - "min_breadth": args.min_breadth, - "min_avg_read_ani": args.min_avg_read_ani, - "min_coverage_evenness": args.min_coverage_evenness, - "min_coeff_var": args.min_coeff_var, - "min_coverage_mean": args.min_coverage_mean, - "min_norm_entropy": min_norm_entropy, - "min_norm_gini": min_norm_gini, - } - else: - min_norm_gini = 0.5 - min_norm_entropy = 0.75 - logging.warning( - f"There's only one genome. Using min_norm_gini={min_norm_gini} and min_norm_entropy={min_norm_entropy}. Please check the results." 
- ) - filter_conditions = { - "min_read_length": args.min_read_length, - "min_read_count": args.min_read_count, - "min_expected_breadth_ratio": args.min_expected_breadth_ratio, - "min_breadth": args.min_breadth, - "min_avg_read_ani": args.min_avg_read_ani, - "min_coverage_evenness": args.min_coverage_evenness, - "min_coeff_var": args.min_coeff_var, - "min_coverage_mean": args.min_coverage_mean, - "min_norm_entropy": min_norm_entropy, - "min_norm_gini": min_norm_gini, - } - else: - min_norm_gini, min_norm_entropy = args.min_norm_gini, args.min_norm_entropy - filter_conditions = { - "min_read_length": args.min_read_length, - "min_read_count": args.min_read_count, - "min_expected_breadth_ratio": args.min_expected_breadth_ratio, - "min_breadth": args.min_breadth, - "min_avg_read_ani": args.min_avg_read_ani, - "min_coverage_evenness": args.min_coverage_evenness, - "min_coeff_var": args.min_coeff_var, - "min_coverage_mean": args.min_coverage_mean, - "min_norm_entropy": min_norm_entropy, - "min_norm_gini": min_norm_gini, - } - - if args.stats_filtered is not None: - filter_reference_BAM( - bam=bam, - df=data_df, - filter_conditions=filter_conditions, - transform_cov_evenness=args.transform_cov_evenness, - threads=args.threads, - out_files=out_files, - sort_memory=args.sort_memory, - sort_by_name=args.sort_by_name, - min_read_ani=args.min_read_ani, - disable_sort=args.disable_sort, - ) - else: - logging.info("Skipping filtering of reference BAM file.") - if args.low_memory: - os.remove(out_files["bam_tmp_sorted"]) - logging.info("ALL DONE.") + if args.action == "reassign": + reassign(args) + elif args.action == "filter": + filter_references(args) + elif args.action == "lca": + do_lca(args) if __name__ == "__main__": diff --git a/bam_filter/filter.py b/bam_filter/filter.py new file mode 100644 index 0000000..4e47bdb --- /dev/null +++ b/bam_filter/filter.py @@ -0,0 +1,252 @@ +""" +This program is free software: you can redistribute it and/or modify it under the terms of the GNU +General Public License as published by the Free Software Foundation, either version 3 of the +License, or (at your option) any later version. + +This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without +even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +General Public License for more details. + +You should have received a copy of the GNU General Public License along with this program. If not, +see . 
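The rewritten `main()` above only dispatches on `args.action`. A minimal sketch of how the three subcommands could be wired with argparse subparsers is shown below; the real parser lives in `get_arguments()` in `bam_filter/utils.py`, which is not part of this diff, so the option names here are illustrative only.

```python
# Hypothetical sketch of the subcommand wiring assumed by main(); the actual
# parser is built by bam_filter.utils.get_arguments(), not shown in this diff.
import argparse

def get_arguments_sketch():
    parser = argparse.ArgumentParser(prog="filterBAM")
    sub = parser.add_subparsers(dest="action", required=True)
    for name in ("reassign", "filter", "lca"):
        p = sub.add_parser(name, help=f"run the {name} step")
        p.add_argument("--bam", required=True)
        p.add_argument("--threads", type=int, default=1)
        p.add_argument("--debug", action="store_true")
    return parser.parse_args()
```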
+""" + + +import logging +import pandas as pd +from bam_filter.sam_utils import process_bam, filter_reference_BAM, check_bam_file +from bam_filter.utils import ( + get_arguments, + create_output_files, + concat_df, + check_tmp_dir_exists, + handle_warning, +) +from bam_filter.entropy import find_knee +import json +import warnings +from collections import Counter +from functools import reduce +import os + +log = logging.getLogger("my_logger") + + +def obj_dict(obj): + return obj.__dict__ + + +def get_summary(obj): + return obj.to_summary() + + +def get_lens(obj): + return obj.get_read_length_freqs() + + +def filter_references(args): + logging.basicConfig( + level=logging.DEBUG, + format="%(levelname)s ::: %(asctime)s ::: %(message)s", + datefmt="%H:%M:%S", + ) + + args = get_arguments() + + tmp_dir = check_tmp_dir_exists(args.tmp_dir) + log.info("Temporary directory: %s", tmp_dir.name) + if args.trim_min >= args.trim_max: + log.error("trim_min must be less than trim_max") + exit(1) + + logging.getLogger("my_logger").setLevel( + logging.DEBUG if args.debug else logging.INFO + ) + + if args.debug: + warnings.showwarning = handle_warning + else: + warnings.filterwarnings("ignore") + + plot_coverage = False + if args.coverage_plots is not None: + plot_coverage = True + + if args.stats_filtered is None and (args.bam_filtered is not None): + logging.error( + "You need to specify a filtereds stats file to obtain the filtered BAM file" + ) + exit(1) + + out_files = create_output_files( + prefix=args.prefix, + bam=args.bam, + stats=args.stats, + stats_filtered=args.stats_filtered, + bam_filtered=args.bam_filtered, + read_length_freqs=args.read_length_freqs, + read_hits_count=args.read_hits_count, + knee_plot=args.knee_plot, + coverage_plots=args.coverage_plots, + tmp_dir=tmp_dir, + ) + + bam = check_bam_file( + bam=args.bam, + threads=args.threads, + reference_lengths=args.reference_lengths, + sort_memory=args.sort_memory, + ) + data = process_bam( + bam=bam, + threads=args.threads, + reference_lengths=args.reference_lengths, + min_read_count=args.min_read_count, + min_read_ani=args.min_read_ani, + trim_ends=args.trim_ends, + trim_min=args.trim_min, + trim_max=args.trim_max, + scale=args.scale, + plot=plot_coverage, + plots_dir=out_files["coverage_plot_dir"], + read_length_freqs=args.read_length_freqs, + output_files=out_files, + sort_memory=args.sort_memory, + low_memory=args.low_memory, + ) + if args.low_memory: + bam = out_files["bam_tmp_sorted"] + + logging.info("Reducing results to a single dataframe") + # data = list(filter(None, data)) + data_df = [x[0] for x in data if x[0] is not None] + data_df = concat_df(data_df) + + if args.read_length_freqs is not None: + logging.info("Calculating read length frequencies...") + lens = [x[1] for x in data if x[1] is not None] + lens = json.dumps(lens, default=obj_dict, ensure_ascii=False, indent=4) + with open(out_files["read_length_freqs"], "w", encoding="utf-8") as outfile: + print(lens, file=outfile) + + if args.read_hits_count is not None: + logging.info("Calculating read hits counts...") + hits = [x[2] for x in data if x[2] is not None] + + # merge dicts and sum values + hits = reduce(lambda x, y: x.update(y) or x, (Counter(dict(x)) for x in hits)) + # hits = sum(map(Counter, hits), Counter()) + + # convert dict to dataframe + hits = ( + pd.DataFrame.from_dict(hits, orient="index", columns=["count"]) + .rename_axis("read") + .reset_index() + .sort_values(by="count", ascending=False) + ) + + hits.to_csv( + out_files["read_hits_count"], + sep="\t", + 
index=False, + ) + + logging.info(f"Writing reference statistics to {out_files['stats']}") + data_df.to_csv(out_files["stats"], sep="\t", index=False) + + if args.min_norm_entropy is None or args.min_norm_gini is None: + filter_conditions = { + "min_read_length": args.min_read_length, + "min_read_count": args.min_read_count, + "min_expected_breadth_ratio": args.min_expected_breadth_ratio, + "min_breadth": args.min_breadth, + "min_avg_read_ani": args.min_avg_read_ani, + "min_coverage_evenness": args.min_coverage_evenness, + "min_coeff_var": args.min_coeff_var, + "min_coverage_mean": args.min_coverage_mean, + } + elif args.min_norm_entropy == "auto" or args.min_norm_gini == "auto": + if data_df.shape[0] > 1: + min_norm_gini, min_norm_entropy = find_knee( + data_df, out_plot_name=out_files["knee_plot"] + ) + + if min_norm_gini is None or min_norm_entropy is None: + logging.warning( + "Could not find knee in entropy plot. Disabling filtering by entropy/gini inequality." + ) + filter_conditions = { + "min_read_length": args.min_read_length, + "min_read_count": args.min_read_count, + "min_expected_breadth_ratio": args.min_expected_breadth_ratio, + "min_breadth": args.min_breadth, + "min_avg_read_ani": args.min_avg_read_ani, + "min_coverage_evenness": args.min_coverage_evenness, + "min_coeff_var": args.min_coeff_var, + "min_coverage_mean": args.min_coverage_mean, + } + else: + filter_conditions = { + "min_read_length": args.min_read_length, + "min_read_count": args.min_read_count, + "min_expected_breadth_ratio": args.min_expected_breadth_ratio, + "min_breadth": args.min_breadth, + "min_avg_read_ani": args.min_avg_read_ani, + "min_coverage_evenness": args.min_coverage_evenness, + "min_coeff_var": args.min_coeff_var, + "min_coverage_mean": args.min_coverage_mean, + "min_norm_entropy": min_norm_entropy, + "min_norm_gini": min_norm_gini, + } + else: + min_norm_gini = 0.5 + min_norm_entropy = 0.75 + logging.warning( + f"There's only one genome. Using min_norm_gini={min_norm_gini} and min_norm_entropy={min_norm_entropy}. Please check the results." 
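The `filter_conditions` dictionary assembled in this function is handed to `filter_reference_BAM()` (defined in `bam_filter/sam_utils.py`). As a rough illustration of how such a threshold dictionary maps onto the per-reference stats table, the sketch below assumes the stats columns are named after the thresholds with the `min_` prefix stripped; those column names are an assumption, not the project's schema.

```python
# Minimal sketch, not the package's actual filtering code: keep rows of a stats
# DataFrame that satisfy every "min_<metric>" threshold. Column names are
# assumed to match the threshold keys without the "min_" prefix (hypothetical).
import pandas as pd

def apply_min_filters(df: pd.DataFrame, filter_conditions: dict) -> pd.DataFrame:
    mask = pd.Series(True, index=df.index)
    for key, threshold in filter_conditions.items():
        column = key[4:] if key.startswith("min_") else key
        if column in df.columns:
            mask &= df[column] >= threshold
    return df[mask]
```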
+ ) + filter_conditions = { + "min_read_length": args.min_read_length, + "min_read_count": args.min_read_count, + "min_expected_breadth_ratio": args.min_expected_breadth_ratio, + "min_breadth": args.min_breadth, + "min_avg_read_ani": args.min_avg_read_ani, + "min_coverage_evenness": args.min_coverage_evenness, + "min_coeff_var": args.min_coeff_var, + "min_coverage_mean": args.min_coverage_mean, + "min_norm_entropy": min_norm_entropy, + "min_norm_gini": min_norm_gini, + } + else: + min_norm_gini, min_norm_entropy = args.min_norm_gini, args.min_norm_entropy + filter_conditions = { + "min_read_length": args.min_read_length, + "min_read_count": args.min_read_count, + "min_expected_breadth_ratio": args.min_expected_breadth_ratio, + "min_breadth": args.min_breadth, + "min_avg_read_ani": args.min_avg_read_ani, + "min_coverage_evenness": args.min_coverage_evenness, + "min_coeff_var": args.min_coeff_var, + "min_coverage_mean": args.min_coverage_mean, + "min_norm_entropy": min_norm_entropy, + "min_norm_gini": min_norm_gini, + } + + if args.stats_filtered is not None: + filter_reference_BAM( + bam=bam, + df=data_df, + filter_conditions=filter_conditions, + transform_cov_evenness=args.transform_cov_evenness, + threads=args.threads, + out_files=out_files, + sort_memory=args.sort_memory, + sort_by_name=args.sort_by_name, + min_read_ani=args.min_read_ani, + disable_sort=args.disable_sort, + ) + else: + logging.info("Skipping filtering of reference BAM file.") + + if args.low_memory: + os.remove(out_files["bam_tmp_sorted"]) + logging.info("ALL DONE.") diff --git a/bam_filter/lca.py b/bam_filter/lca.py new file mode 100644 index 0000000..e291f57 --- /dev/null +++ b/bam_filter/lca.py @@ -0,0 +1,585 @@ +import pysam +import taxopy as txp +from tqdm import tqdm +from multiprocessing import Pool, Manager +from functools import partial +import pandas as pd +import logging +import networkx as nx +from concurrent.futures import ProcessPoolExecutor +import sys +from bam_filter.utils import ( + calc_chunksize, + sort_keys_by_approx_weight, + concat_df, + is_debug, + create_output_files, +) +from collections import defaultdict +from functools import reduce +import operator +import random +import gzip + +log = logging.getLogger("my_logger") + +debug = is_debug() + + +def calculate_path_likelihood(path, graph): + likelihood = 1.0 + for u, v in zip(path[:-1], path[1:]): + likelihood *= graph[u][v]["cum_weight"] + return likelihood + + +def find_most_likely_continuation_worker(partial_path_end, full_graph, index): + # find where are we in the taxonomy tree, if we are in the root, return + + results = [] + try: + descendants = list(nx.descendants(full_graph, partial_path_end)) + except nx.NetworkXError: + descendants = [] + + if not descendants: + return ( + partial_path_end, + { + "reference": None, + "best_path": None, + "top_10_paths": None, + }, + ) + + if ( + len(nx.shortest_path(full_graph, source="root", target=partial_path_end)) + <= index + 1 + ): + print(partial_path_end) + return ( + partial_path_end, + { + "reference": None, + "best_path": None, + "top_10_paths": None, + }, + ) + + for continuation in descendants: + full_path = list( + nx.all_simple_paths( + full_graph, source=partial_path_end, target=continuation + ) + )[0] + likelihood = calculate_path_likelihood(full_path, full_graph) + results.append((full_path, likelihood)) + + results.sort(key=lambda x: x[1], reverse=True) + + return ( + partial_path_end, + { + "reference": results[0][0][-1] if results else None, + "best_path": results[0] if results 
else None, + "top_10_paths": results[:10] if results else None, + }, + ) + + +def find_most_likely_continuation(full_graph, leaves, index, num_workers=1): + result_dict = {} + + with ProcessPoolExecutor(max_workers=num_workers) as executor: + futures = [ + executor.submit( + find_most_likely_continuation_worker, + partial_path_end, + full_graph, + index, + ) + for partial_path_end in leaves + ] + + for future in tqdm(futures, total=len(leaves), leave=False, ncols=80): + partial_path_end, result = future.result() + result_dict[partial_path_end] = result + + return result_dict + + +def calculate_cumulative_weight_and_path(graph, node, lengths): + # Base case: if the node is the root, return its cumulative weight and path + if not list(graph.predecessors(node)): + if node in lengths: + return 0, 0, [node] + else: + nei = list(graph.neighbors(node))[0] + cumulative_weight = graph[node][nei].get("weight", 0) + cumulative_norm_weight = graph[node][nei].get("norm_weight", 0) + return cumulative_weight, cumulative_norm_weight, [node] + + # Recursive case: calculate cumulative weight and path by summing the edge weight + # and norm_weight with the cumulative weight and path of its parent(s) + cumulative_weight = 0 + cumulative_norm_weight = 0 + cumulative_path = [] + + for parent in graph.predecessors(node): + edge_weight = graph[parent][node].get("weight", 0) + norm_weight = graph[parent][node].get("norm_weight", 0) + ( + parent_weight, + parent_norm_weight, + parent_path, + ) = calculate_cumulative_weight_and_path(graph, parent, lengths) + + cumulative_weight += edge_weight + parent_weight + cumulative_norm_weight += norm_weight + parent_norm_weight + cumulative_path.extend(parent_path + [node]) # Fix: Use += to concatenate lists + + return cumulative_weight, cumulative_norm_weight, cumulative_path + + +def create_tax_graph_w(tax_path, weight, lengths, scale=1_000_000): + root_row = pd.DataFrame({"source": ["root"], "target": ["root"], "weight": 0}) + res = list(zip(tax_path, tax_path[1:])) + # get last element + res = pd.DataFrame(res, columns=["source", "target"]) + res = pd.concat([root_row, res]) + res = res.drop_duplicates() + res["weight"] = 0 + res["norm_weight"] = 0 + res.iloc[-1, res.columns.get_loc("weight")] = weight + res.iloc[-1, res.columns.get_loc("norm_weight")] = round( + (weight / lengths[res.iloc[-1, res.columns.get_loc("target")]]) * scale + ) + return res + + +def create_lca_df(tax_path, weight): + root_row = pd.DataFrame({"source": ["root"], "target": ["root"], "weight": 0}) + res = list(zip(tax_path, tax_path[1:])) + # get last element + res = pd.DataFrame(res, columns=["source", "target"]) + res = pd.concat([root_row, res]) + res = res.drop_duplicates() + res["weight"] = 0 + # add 1 to the last row weight + res.iloc[-1, res.columns.get_loc("weight")] = weight + return res + + +def get_ref2read(params, dat, threads=1): + bam, references = params + samfile = pysam.AlignmentFile(bam, "rb", threads=threads) + results = defaultdict(set) + for reference in references: + if reference not in dat: + continue + for aln in samfile.fetch( + contig=reference, multiple_iterators=False, until_eof=True + ): + results[reference].add(aln.query_name) + samfile.close() + return results + + +def get_tax(ref, parms): + taxdb = parms["taxdb"] + acc2taxid = parms["acc2taxid"] + if ref in acc2taxid: + taxid = acc2taxid[ref] + # taxid = txp.taxid_from_name(ref, taxdb)[0] + taxonomy_info = txp.Taxon(taxid, taxdb).rank_name_dictionary + taxonomy_info["taxid"] = taxid + taxonomy_info["ref"] = ref + 
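`get_tax()` resolves each reference accession to its ranked lineage with taxopy. A small usage sketch of the same calls, on a placeholder accession-to-taxid mapping and placeholder dump files:

```python
# Usage sketch of the taxopy calls used in get_tax(); the dump-file paths and
# the accession -> taxid mapping below are placeholders.
import taxopy as txp

taxdb = txp.TaxDb(nodes_dmp="nodes.dmp", names_dmp="names.dmp")
acc2taxid = {"NC_000913.3": 511145}  # hypothetical accession -> taxid pair

taxid = acc2taxid["NC_000913.3"]
lineage = txp.Taxon(taxid, taxdb).rank_name_dictionary
print(lineage.get("genus"), lineage.get("species"))
```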
taxonomy_info["subspecies"] = f"S__{ref}" + else: + log.debug(f"No taxid found for {ref}") + taxonomy_info = None + return taxonomy_info + + +def get_taxonomy_info(refids, taxdb, acc2taxid, nprocs=1): + """Function to get the references taxonomic information for a given taxonomy id + + Args: + taxids (list): A list of taxonomy ids + taxdb (taxopy.TaxonomyDB): A taxopy DB + + Returns: + dict: A list of taxonomy information + """ + + acc2taxid_df = pd.read_csv(acc2taxid, sep="\t", index_col=None)[ + ["accession", "taxid"] + ].rename(columns={"accession": "reference"}, inplace=False) + # Filter rows in refids from dataframe + acc2taxid_df = acc2taxid_df.loc[acc2taxid_df["reference"].isin(refids)] + acc2taxid_dict = acc2taxid_df.set_index("reference").T.to_dict("records") + + parms = {"taxdb": taxdb, "acc2taxid": acc2taxid_dict[0]} + func = partial(get_tax, parms=parms) + if debug is True or len(refids) < 100000: + taxonomy_info = list(map(func, refids)) + else: + p = Pool(nprocs) + c_size = calc_chunksize(nprocs, len(refids)) + taxonomy_info = list( + tqdm( + p.imap_unordered(func, refids, chunksize=c_size), + total=len(refids), + leave=False, + ncols=100, + desc="References processed", + ) + ) + p.close() + p.join() + taxonomy_info = list(filter(None, taxonomy_info)) + exclude = ["taxid", "ref"] + tax_ranks = [] + + for k in taxonomy_info[0].keys(): + if k not in exclude: + tax_ranks.append(k) + + taxonomy_info = {i["ref"]: i for i in taxonomy_info} + return taxonomy_info, tax_ranks + + +def do_lca(args): + bam = args.bam + names = args.names + nodes = args.nodes + acc2taxid = args.acc2taxid + sel_rank = args.rank_lca + reference_lengths = args.reference_lengths + threads = args.threads + scale = args.scale + + out_files = create_output_files( + bam=bam, prefix=args.prefix, lca_summary=args.lca_summary, tmp_dir=None + ) + + samfile = pysam.AlignmentFile(bam, "rb", threads=threads) + references = samfile.references + references_m = { + chrom.contig: chrom.mapped for chrom in samfile.get_index_statistics() + } + + if reference_lengths is not None: + ref_lengths = pd.read_csv( + reference_lengths, sep="\t", index_col=0, names=["reference", "length"] + ) + ref_lengths = dict(zip(ref_lengths["reference"], ref_lengths["length"])) + # check if the dataframe contains all the References in the BAM file + if not set(references).issubset(set(ref_lengths.index)): + logging.error( + "The BAM file contains references not found in the reference lengths file" + ) + sys.exit(1) + else: + ref_lengths = {x: samfile.get_reference_length(x) for x in references} + + log.info("Getting taxonomy information") + taxdb = txp.TaxDb( + nodes_dmp=nodes, + names_dmp=names, + ) + acc2taxid = acc2taxid + taxonomy_info, tax_ranks = get_taxonomy_info( + references, taxdb, acc2taxid, nprocs=threads + ) + tax_ranks.reverse() + index_r = tax_ranks.index(sel_rank) + + dat = {} + for k, v in taxonomy_info.items(): + try: + tax_list = ["root"] + tax_list.extend([v[rank] for rank in tax_ranks]) + tax_list.append(k) + dat[k] = tuple(tax_list) + except KeyError: + continue + + log.info("Getting reference to read mapping") + ref_chunks = sort_keys_by_approx_weight( + references_m, scale=1.5, num_cores=threads, refinement_steps=100 + ) + + ref_chunks = random.sample(ref_chunks, len(ref_chunks)) + dat_man = Manager().dict(dat) + params = zip([bam] * len(ref_chunks), ref_chunks) + p = Pool( + threads, + ) + data = list( + tqdm( + p.imap_unordered( + partial( + get_ref2read, + dat=dat_man, + threads=1, + ), + params, + chunksize=1, + ), + 
total=len(ref_chunks), + leave=False, + ncols=80, + desc="Chunks processed", + ) + ) + + p.close() + p.join() + + data = reduce(operator.ior, data, {}) + + log.info("Getting reads taxonomic information") + reads = defaultdict(set) + for k, v in tqdm(data.items(), total=len(data), leave=False, ncols=80): + tax = dat[k] + for read in v: + reads[read].add(tax) + + log.info("Creating taxonomic graph") + root_row = pd.DataFrame({"source": ["root"], "target": ["root"], "weight": [0]}) + gs = [] + if len(dat) > 0: + for t in tqdm(dat.values(), total=len(dat), leave=False, ncols=80): + # remove last element + # t = t[:-1] + res = list(zip(t, t[1:])) + res = pd.DataFrame(res, columns=["source", "target"]) + res = res.drop_duplicates() + res["weight"] = 0 + res = pd.concat([root_row, res]) + gs.append(res) + gs = concat_df(gs).drop_duplicates() + + G = nx.from_pandas_edgelist(gs, create_using=nx.DiGraph(), edge_attr=True) + + dat = None + + log.info("Calculating LCA") + cum_tax = defaultdict(int) + unique = [] + + discarded_lca = defaultdict(int) + for_lca = defaultdict(int) + + # for k, v in tqdm(reads.items(), total=len(reads), leave=False, ncols=80): + # v = list(v) + # if len(v) > 1: + # tax_path = [edge for edge in v[0] if all(edge in t for t in v[1:])] + # if len(tax_path) <= index_r + 1: + # discarded_lca[tuple(tax_path)] += 1 + # continue + # for_lca[tuple(tax_path)] += 1 + # else: + # unique.append(k) + # cum_tax[v[0]] += 1 + + for k, v in tqdm(reads.items(), total=len(reads), leave=False, ncols=80): + v = list(v) + if len(v) > 1: + tax_path_set = set(v[0]) + for t in v[1:]: + tax_path_set.intersection_update(t) + + tax_path = list(tax_path_set) + + if len(tax_path) <= index_r + 1: + discarded_lca[tuple(tax_path)] += 1 + continue + + for_lca[tuple(tax_path)] += 1 + else: + unique.append(k) + cum_tax[v[0]] += 1 + + log.info( + f"Unique: {len(unique)} | LCA: {len(for_lca)} | Discarded: {len(discarded_lca)}" + ) + log.info("Creating LCA dataframe") + if len(for_lca) > 0: + for_lca_df = [] + for k, v in tqdm(for_lca.items(), total=len(for_lca), leave=False, ncols=80): + for_lca_df.append(create_lca_df(list(k), v)) + + log.info("Adding weights to the unique mapping taxonomic graph") + gs = [] + for k, v in tqdm(cum_tax.items(), total=len(cum_tax), leave=False, ncols=80): + gs.append(create_tax_graph_w(list(k), v, ref_lengths, scale=scale)) + + df = concat_df(gs) + if len(for_lca) > 0: + df_l = concat_df(for_lca_df) + df_l["weight"] = 0 + df_l["norm_weight"] = 0 + df = concat_df([df, df_l]) + # group by source and target and sum weight + df = df.groupby(["source", "target"]).sum().reset_index() + + G = nx.from_pandas_edgelist(df, create_using=nx.DiGraph(), edge_attr=True) + G.remove_edges_from(nx.selfloop_edges(G)) + Gr = G.reverse() + + log.info("Calculating cumulative weights and paths on the taxonomic graph") + cumulative_weights_and_paths = {} + modified_graph = Gr.copy() # Create a copy of the original graph + + for node in Gr.nodes: + ( + cumulative_weight, + cumulative_norm_weight, + path, + ) = calculate_cumulative_weight_and_path(modified_graph, node, ref_lengths) + cumulative_weights_and_paths[node] = { + "cumulative_weight": cumulative_weight, + "cumulative_norm_weight": cumulative_norm_weight, + "path": ";".join(nx.shortest_path(G, target=node)["root"]), + } + + # Update the modified graph with inferred edge attributes after all weights have been calculated + for parent in modified_graph.predecessors(node): + modified_graph[parent][node]["cum_weight"] = ( + 
modified_graph[parent][node].get("weight", 0) + cumulative_weight + ) + modified_graph[parent][node]["cum_norm_weight"] = ( + modified_graph[parent][node].get("norm_weight", 0) + + cumulative_norm_weight + ) + + cumulative_weights_and_paths = dict( + sorted(cumulative_weights_and_paths.items(), key=lambda item: item[1]["path"]) + ) + + if len(for_lca) > 0: + df1 = concat_df(for_lca_df) + df1 = df1.groupby(["source", "target"]).sum().reset_index() + G_lca = nx.from_pandas_edgelist(df1, create_using=nx.DiGraph(), edge_attr=True) + G_lca.remove_edges_from(nx.selfloop_edges(G_lca)) + + df1_d = dict( + zip(df1[df1["weight"] > 0]["target"], df1[df1["weight"] > 0]["weight"]) + ) + leaves = list(df1_d.keys()) + + log.info( + f"Finding most likely reference for the LCA nodes [{threads} threads]]" + ) + test = find_most_likely_continuation( + full_graph=modified_graph.reverse(), + leaves=leaves, + index=index_r, + num_workers=threads, + ) + + lca_dfs = [] + for k, v in test.items(): + tax_path = nx.shortest_path(G, target=k)["root"] + root_row = pd.DataFrame( + {"source": ["root"], "target": ["root"], "weight": 0} + ) + res = list(zip(tax_path, tax_path[1:])) + # get last element + res = pd.DataFrame(res, columns=["source", "target"]) + res = pd.concat([root_row, res]) + res = res.drop_duplicates() + res["weight"] = 0 + res["norm_weight"] = 0 + res.iloc[-1, res.columns.get_loc("weight")] = df1_d[k] + if v["reference"] is None: + res.iloc[-1, res.columns.get_loc("norm_weight")] = df1_d[k] + else: + res.iloc[-1, res.columns.get_loc("norm_weight")] = round( + scale * df1_d[k] / ref_lengths[v["reference"]] + ) + lca_dfs.append(res) + log.info("Adding LCA nodes to the taxonomic graph") + df2 = concat_df(lca_dfs) + df2 = concat_df([df2, df]) + df2 = df2.groupby(["source", "target"]).sum().reset_index() + # %% + G2 = nx.from_pandas_edgelist(df2, create_using=nx.DiGraph(), edge_attr=True) + G2.remove_edges_from(nx.selfloop_edges(G2)) + Gr2 = G2.reverse() + # %% + log.info("Calculating cumulative weights and paths on the taxonomic graph") + cumulative_weights_and_paths = {} + modified_graph = Gr2.copy() # Create a copy of the original graph + + for node in Gr2.nodes: + ( + cumulative_weight, + cumulative_norm_weight, + path, + ) = calculate_cumulative_weight_and_path(modified_graph, node, ref_lengths) + + cumulative_weights_and_paths[node] = { + "cumulative_weight": cumulative_weight, + "cumulative_norm_weight": cumulative_norm_weight, + "path": ";".join(nx.shortest_path(G, target=node)["root"]), + } + + # Update the modified graph with inferred edge attributes after all weights have been calculated + for parent in modified_graph.predecessors(node): + modified_graph[parent][node]["cum_weight"] = ( + modified_graph[parent][node].get("weight", 0) + cumulative_weight + ) + modified_graph[parent][node]["cum_norm_weight"] = ( + modified_graph[parent][node].get("norm_weight", 0) + + cumulative_norm_weight + ) + cumulative_weights_and_paths = dict( + sorted( + cumulative_weights_and_paths.items(), key=lambda item: item[1]["path"] + ) + ) + + log.info("Writing LCA results to file") + nodes = list( + set( + [ + node + for node, data in cumulative_weights_and_paths.items() + if data["cumulative_weight"] > 0 + ] + ) + ) + taxids = txp.taxid_from_name(nodes, taxdb) + taxids = dict(zip(nodes, taxids)) + + out_file = out_files["lca_summary"] + # determine if the file is gzipped or not + if out_file.endswith(".gz"): + with gzip.open(out_file, "wt") as f: + f.write("taxid\tname\trank\tn_reads\tabundance\ttax_path\n") + for 
node, data in tqdm( + cumulative_weights_and_paths.items(), + total=len(cumulative_weights_and_paths), + ): + if data["cumulative_weight"] > 0: + taxid = taxids[node][0] + rank = taxdb.taxid2rank[taxid] + f.write( + f"{taxid}\t{node}\t{rank}\t{data['cumulative_weight']}\t{data['cumulative_norm_weight']}\t{data['path']}\n" + ) + else: + with open(out_file, "wt") as f: + f.write("taxid\tname\trank\tn_reads\tabundance\ttax_path\n") + for node, data in tqdm( + cumulative_weights_and_paths.items(), + total=len(cumulative_weights_and_paths), + ): + if data["cumulative_weight"] > 0: + taxid = taxids[node][0] + rank = taxdb.taxid2rank[taxid] + f.write( + f"{taxid}\t{node}\t{rank}\t{data['cumulative_weight']}\t{data['cumulative_norm_weight']}\t{data['path']}\n" + ) diff --git a/bam_filter/reassign.py b/bam_filter/reassign.py new file mode 100644 index 0000000..69afbe5 --- /dev/null +++ b/bam_filter/reassign.py @@ -0,0 +1,785 @@ +import datatable as dt +import logging +import numpy as np +import pysam +import tqdm +import sys +import random +from bam_filter.utils import ( + create_empty_output_files, + sort_keys_by_approx_weight, + is_debug, + get_arguments, + check_tmp_dir_exists, + handle_warning, + create_output_files, +) +from multiprocessing import Pool, Manager, cpu_count +from functools import partial +import numpy.lib.recfunctions as rf +import gc +from collections import defaultdict +import os +import concurrent.futures +import math +import warnings + +# import cProfile as prof +# import pstats + +log = logging.getLogger("my_logger") + + +def initialize_subject_weights(data): + if data.shape[0] > 0: + # Add a new column for inverse sequence length + data["s_W"] = 1 / data["slen"] + + # Calculate the sum of weights for each unique source + sum_weights = np.zeros(int(np.max(data["source"])) + 1) + np.add.at(sum_weights, data["source"], data["var"]) + + # Calculate the normalized weights based on the sum for each query + query_sum_var = sum_weights[data["source"]] + data["prob"] = data["var"] / query_sum_var + + return data + else: + return None + + +def resolve_multimaps(data, scale=0.9, iters=10): + current_iter = 0 + while True: + progress_bar = tqdm.tqdm( + total=9, + desc=f"Iter {current_iter + 1}", + unit="step", + disable=False, # Replace with your logic or a boolean value + leave=False, + ncols=80, + ) + log.debug(f"::: Iter: {current_iter + 1} - Getting scores") + progress_bar.update(1) + n_alns = data.shape[0] + log.debug(f"::: Iter: {current_iter + 1} - Total alignment: {n_alns:,}") + + # Calculate the weights for each subject + log.debug(f"::: Iter: {current_iter + 1} - Calculating weights...") + progress_bar.update(1) + subject_weights = np.zeros(int(np.max(data["subject"])) + 1) + np.add.at(subject_weights, data["subject"], data["prob"]) + data["s_W"] = subject_weights[data["subject"]] / data["slen"] + subject_weights = None + + log.debug(f"::: Iter: {current_iter + 1} - Calculating probabilities") + progress_bar.update(1) + # Calculate the alignment probabilities + new_prob = data["prob"] * data["s_W"] + log.debug("Calculating sum of probabilities") + progress_bar.update(1) + prob_sum = data["prob"] * data["s_W"] + prob_sum_array = np.zeros(int(np.max(data["source"])) + 1) + np.add.at(prob_sum_array, data["source"], prob_sum) + prob_sum = None + + # data["prob_sum"] = prob_sum_array[data["source"]] + data["prob"] = new_prob / prob_sum_array[data["source"]] + prob_sum_array = None + + log.debug("Calculating query counts") + progress_bar.update(1) + # Calculate how many alignments 
are in each query + # query_counts = np.zeros(int(np.max(data["source"])) + 1) + # np.add.at(query_counts, data["source"], 1) + + query_counts = np.bincount(data["source"]) + + log.debug("Calculating query counts array") + progress_bar.update(1) + # Use a separate array for query counts + query_counts_array = np.zeros(int(np.max(data["source"])) + 1) + np.add.at( + query_counts_array, + data["source"], + query_counts[data["source"]], + ) + + log.debug( + f"::: Iter: {current_iter + 1} - Calculating number of alignments per query" + ) + progress_bar.update(1) + data["n_aln"] = query_counts_array[data["source"]] + + log.debug("Calculating unique alignments") + data["n_aln"] = query_counts_array[data["source"]] + data_unique = data[data["n_aln"] == 1] + n_unique = data_unique.shape[0] + + data = data[(data["n_aln"] > 1) & (data["prob"] > 0)] + + # total_n_unique = np.sum(query_counts_array[data["source"]] <= 1) + + query_counts = None + query_counts_array = None + + log.debug("Calculating max_prob") + # Keep the ones that have a probability higher than the maximum scaled probability + max_prob = np.zeros(int(np.max(data["source"])) + 1) + np.maximum.at(max_prob, data["source"], data["prob"]) + + data["max_prob"] = max_prob[data["source"]] + data["max_prob"] = data["max_prob"] * scale + # data["max_prob"] = max_prob[data["source"]] + log.debug( + f"::: Iter: {current_iter + 1} - Removing alignments with lower probability" + ) + progress_bar.update(1) + to_remove = np.sum(data["prob"] < data["max_prob"]) + + data = data[data["prob"] >= data["max_prob"]] + max_prob = None + + # Update the iteration count in the function call + current_iter += 1 + data["iter"] = current_iter + data_unique["iter"] = current_iter + + query_counts = np.bincount(data["source"]) + total_n_unique = np.sum(query_counts[data["source"]] <= 1) + + # data_unique["iter"] = current_iter + + # data = np.concatenate([data, data_unique]) + data = np.concatenate([data, data_unique]) + data_unique = None + + keep_processing = to_remove != 0 + log.debug(f"::: Iter: {current_iter} - Removed {to_remove:,} alignments") + log.debug(f"::: Iter: {current_iter} - Total mapping queries: {n_unique:,}") + log.debug( + f"::: Iter: {current_iter} - New unique mapping queries: {total_n_unique:,}" + ) + log.debug(f"::: Iter: {current_iter} - Alns left: {data.shape[0]:,}") + progress_bar.update(1) + progress_bar.close() + log.info( + f"::: Iter: {current_iter} - R: {to_remove:,} | U: {total_n_unique:,} | NU: {n_unique:,} | L: {data.shape[0]:,}" + ) + log.debug(f"::: Iter: {current_iter} - done!") + + if iters > 0 and current_iter >= iters: + log.info("::: ::: Reached maximum iterations. Stopping.") + break + elif not keep_processing: + log.info("::: ::: No more alignments to remove. 
Stopping.") + break + return data + + +# def write_reassigned_bam( +# bam, out_files, threads, entries, sort_memory="1G", min_read_ani=90 +# ): +# if out_files["bam_reassigned"] is not None: +# out_bam = out_files["bam_reassigned"] +# else: +# out_bam = out_files["bam_reassigned_sorted"] + +# samfile = pysam.AlignmentFile(bam, "rb", threads=threads) +# references = list(entries.keys()) +# refs_dict = {x: samfile.get_reference_length(x) for x in list(entries.keys())} + +# (ref_names, ref_lengths) = zip(*refs_dict.items()) + +# refs_idx = {sys.intern(str(x)): i for i, x in enumerate(ref_names)} +# if threads > 4: +# write_threads = 4 +# else: +# write_threads = threads + +# out_bam_file = pysam.AlignmentFile( +# out_files["bam_reassigned_tmp"], +# "wb", +# referencenames=list(ref_names), +# referencelengths=list(ref_lengths), +# threads=write_threads, +# ) + +# for reference in tqdm.tqdm( +# references, +# total=len(references), +# leave=False, +# ncols=80, +# desc="References processed", +# ): +# r_ids = entries[reference] +# for aln in samfile.fetch( +# reference=reference, multiple_iterators=False, until_eof=True +# ): +# # ani_read = (1 - ((aln.get_tag("NM") / aln.infer_query_length()))) * 100 +# if (aln.query_name, reference) in r_ids: +# aln.reference_id = refs_idx[aln.reference_name] +# out_bam_file.write(aln) +# out_bam_file.close() + + +# def write_to_file(alns, out_bam_file): +# for aln in tqdm.tqdm(alns, total=len(alns), leave=False, ncols=80, desc="Writing"): +# out_bam_file.write(aln) + + +def write_to_file(alns, out_bam_file, header=None): + for aln in alns: + out_bam_file.write(pysam.AlignedSegment.fromstring(aln, header)) + + +def process_references_batch(references, entries, bam, refs_idx, threads=4): + alns = [] + with pysam.AlignmentFile(bam, "rb", threads=threads) as samfile: + for reference in references: + r_ids = entries[reference] + for aln in samfile.fetch( + reference=reference, multiple_iterators=False, until_eof=True + ): + if (aln.query_name, reference) in r_ids: + aln.reference_id = refs_idx[aln.reference_name] + alns.append(aln.to_string()) + + return alns + + +def write_reassigned_bam( + bam, + ref_counts, + out_files, + threads, + entries, + sort_memory="1G", + sort_by_name=False, + min_read_ani=90, + min_read_length=30, +): + # if out_files["bam_reassigned"] is not None: + # out_bam = out_files["bam_reassigned"] + # else: + # out_bam = out_files["bam_reassigned_sorted"] + out_bam = out_files["bam_reassigned"] + samfile = pysam.AlignmentFile(bam, "rb", threads=threads) + references = list(entries.keys()) + refs_dict = {x: samfile.get_reference_length(x) for x in references} + header = samfile.header + samfile.close() + (ref_names, ref_lengths) = zip(*refs_dict.items()) + + refs_idx = {sys.intern(str(x)): i for i, x in enumerate(ref_names)} + if threads > 4: + write_threads = 4 + else: + write_threads = threads + + out_bam_file = pysam.AlignmentFile( + out_files["bam_reassigned_tmp"], + "wb", + referencenames=list(ref_names), + referencelengths=list(ref_lengths), + threads=write_threads, + ) + + num_cores = min(threads, cpu_count()) + # batch_size = len(references) // num_cores + 1 # Ensure non-zero batch size + # batch_size = calc_chunksize(n_workers=num_cores, len_iterable=len(references)) + log.info("::: Creating reference chunks with uniform read amounts...") + ref_chunks = sort_keys_by_approx_weight( + input_dict=ref_counts, scale=1.5, num_cores=threads + ) + num_cores = min(num_cores, len(ref_chunks)) + log.info(f"::: Using {num_cores} cores to write 
{len(ref_chunks)} chunk(s)") + + with Manager() as manager: + # Use Manager to create a read-only proxy for the dictionary + entries = manager.dict(dict(entries)) + + with concurrent.futures.ProcessPoolExecutor(max_workers=num_cores) as executor: + # Use ProcessPoolExecutor to parallelize the processing of references in batches + futures = [] + for batch_references in tqdm.tqdm( + ref_chunks, + total=len(ref_chunks), + desc="Submitted batches", + unit="batch", + leave=False, + ncols=80, + disable=is_debug(), + ): + future = executor.submit( + process_references_batch, batch_references, entries, bam, refs_idx + ) + futures.append(future) # Store the future + + # Use a while loop to continuously check for completed futures + log.info("::: Collecting batches...") + + completion_progress_bar = tqdm.tqdm( + total=len(futures), + desc="Completed", + unit="batch", + leave=False, + ncols=80, + disable=is_debug(), + ) + completed_count = 0 + + # Use as_completed to iterate over completed futures as they become available + for completed_future in concurrent.futures.as_completed(futures): + alns = completed_future.result() + write_to_file(alns=alns, out_bam_file=out_bam_file, header=header) + + # Update the progress bar for each completed write + completion_progress_bar.update(1) + completed_count += 1 + completed_future.cancel() # Cancel the future to free memory + gc.collect() # Force garbage collection + + completion_progress_bar.close() + out_bam_file.close() + entries = None + gc.collect() + # prof.disable() + # # print profiling output + # stats = pstats.Stats(prof).strip_dirs().sort_stats("tottime") + # stats.print_stats(5) # top 10 rows + log.info("::: ::: Sorting BAM file...") + if sort_by_name: + log.info("::: ::: Sorting by name...") + pysam.sort( + "-n", + "-@", + str(threads), + "-m", + str(sort_memory), + "-o", + out_bam, + out_files["bam_reassigned_tmp"], + ) + else: + pysam.sort( + "-@", + str(threads), + "-m", + str(sort_memory), + "-o", + out_bam, + out_files["bam_reassigned_tmp"], + ) + + logging.info("BAM index not found. 
Indexing...") + save = pysam.set_verbosity(0) + samfile = pysam.AlignmentFile(out_bam, "rb", threads=threads) + chr_lengths = [] + for chrom in samfile.references: + chr_lengths.append(samfile.get_reference_length(chrom)) + max_chr_length = np.max(chr_lengths) + pysam.set_verbosity(save) + samfile.close() + + if max_chr_length > 536870912: + logging.info("A reference is longer than 2^29, indexing with csi") + pysam.index( + "-c", + "-@", + str(threads), + out_bam, + ) + else: + pysam.index( + "-@", + str(threads), + out_bam, + ) + + os.remove(out_files["bam_reassigned_tmp"]) + + +def calculate_alignment_score(identity, read_length): + return (identity / math.log(read_length + 1)) * math.sqrt(read_length) + + +def get_bam_data(parms, ref_lengths=None, percid=90, min_read_length=30, threads=1): + bam, references = parms + samfile = pysam.AlignmentFile(bam, "rb", threads=threads) + + results = [] + reads = set() + refs = set() + empty_df = 0 + + for reference in references: + if ref_lengths is None: + reference_length = int(samfile.get_reference_length(reference)) + else: + reference_length = int(ref_lengths[reference]) + aln_data = [] + for aln in samfile.fetch( + contig=reference, multiple_iterators=False, until_eof=True + ): + # AS_value = aln.get_tag("AS") if aln.has_tag("AS") else None + # XS_value = aln.get_tag("XS") if aln.has_tag("XS") else None + query_length = ( + aln.query_length if aln.query_length != 0 else aln.infer_query_length() + ) + + # if query_length < min_read_length: + # continue + + pident = (1 - ((aln.get_tag("NM") / query_length))) * 100 + if pident >= percid: + var = calculate_alignment_score(pident, query_length) + aln_data.append( + ( + aln.query_name, + aln.reference_name, + var, + reference_length, + ) + ) + reads.add(aln.query_name) + refs.add(aln.reference_name) + # remove duplicates + # check if aln_data is not empty + if len(aln_data) > 0: + aln_data_dt = dt.Frame(aln_data) + # "queryId", "subjectId", "bitScore", "slen" + aln_data_dt.names = ["queryId", "subjectId", "var", "slen"] + # remove duplicates and keep the ones with the largest mapq + aln_data_dt = aln_data_dt[ + :1, :, dt.by(dt.f.queryId, dt.f.subjectId), dt.sort(-dt.f.var) + ] + # aln_data_dt["slen"] = dt.Frame([ref_lengths[x] for x in aln_data_dt[:,"subjectId"].to_list()[0]]) + results.append(aln_data_dt) + else: + results.append(dt.Frame()) + empty_df += 1 + samfile.close() + return (dt.rbind(results), reads, refs, empty_df) + + +def reassign_reads( + bam, + out_files, + reference_lengths=None, + threads=1, + min_read_count=1, + min_read_ani=90, + min_read_length=30, + reassign_iters=25, + reassign_scale=0.9, + sort_memory="4G", + sort_by_name=False, +): + dt.options.progress.enabled = True + dt.options.progress.clear_on_success = True + if threads > 1: + dt.options.nthreads = threads - 1 + else: + dt.options.nthreads = 1 + + log.info("::: Loading BAM file") + save = pysam.set_verbosity(0) + samfile = pysam.AlignmentFile(bam, "rb", threads=threads) + + references = samfile.references + + pysam.set_verbosity(save) + + if reference_lengths is not None: + ref_len_dt = dt.fread(reference_lengths) + ref_len_dt.names = ["subjectId", "slen"] + # convert to dict + ref_len_dict = dict( + zip(ref_len_dt["subjectId"].to_list()[0], ref_len_dt["slen"].to_list()[0]) + ) + # check if the dataframe contains all the References in the BAM file + if not set(references).issubset(set(ref_len_dict.keys())): + logging.error( + "The BAM file contains references not found in the reference lengths file" + ) + sys.exit(1) 
+ else: + ref_len_dict = None + + total_refs = samfile.nreferences + log.info(f"::: Found {total_refs:,} reference sequences") + # logging.info(f"Found {samfile.mapped:,} alignments") + log.info(f"::: Removing references with less than {min_read_count} reads...") + references_m = { + chrom.contig: chrom.mapped + for chrom in samfile.get_index_statistics() + if chrom.mapped >= min_read_count + } + + references = list(references_m.keys()) + + if len(references) == 0: + log.warning("::: No reference sequences with alignments found in the BAM file") + create_empty_output_files(out_files) + sys.exit(0) + + log.info(f"::: Keeping {len(references):,} references") + + log.info("::: Creating reference chunks with uniform read amounts...") + # ify the number of chunks + ref_chunks = sort_keys_by_approx_weight( + input_dict=references_m, scale=1.5, num_cores=threads + ) + log.info(f"::: Created {len(ref_chunks):,} chunks") + ref_chunks = random.sample(ref_chunks, len(ref_chunks)) + + dt.options.progress.enabled = False + dt.options.progress.clear_on_success = True + dt.options.nthreads = 1 + + parms = list(zip([bam] * len(ref_chunks), ref_chunks)) + + log.info("::: Extracting reads from BAM file...") + if is_debug(): + data = list( + tqdm.tqdm( + map( + partial( + get_bam_data, + ref_lengths=ref_len_dict, + percid=min_read_ani, + min_read_length=min_read_length, + threads=4, + ), + parms, + chunksize=1, + ), + total=len(parms), + leave=False, + ncols=80, + desc="Chunks processed", + ) + ) + else: + p = Pool( + threads, + ) + data = list( + tqdm.tqdm( + p.imap_unordered( + partial( + get_bam_data, + ref_lengths=ref_len_dict, + percid=min_read_ani, + threads=4, + ), + parms, + chunksize=1, + ), + total=len(parms), + leave=False, + ncols=80, + desc="Chunks processed", + ) + ) + + p.close() + p.join() + + dt.options.progress.enabled = True + dt.options.progress.clear_on_success = True + if threads > 1: + dt.options.nthreads = threads - 1 + else: + dt.options.nthreads = 1 + + log.info("::: Collecting results...") + reads = list() + refs = list() + empty_df = 0 + for i in tqdm.tqdm(range(len(data)), total=len(data), leave=False, ncols=80): + empty_df += data[i][3] + reads.extend(list(data[i][1])) + refs.extend(list(data[i][2])) + data[i] = data[i][0] + log.info(f"::: ::: Found {empty_df:,} references without alignments") + + data = dt.rbind([x for x in data]) + + log.info("::: Indexing references...") + refs = dt.Frame(list(set(refs))) + refs.names = ["subjectId"] + refs["sidx"] = dt.Frame(list(range(refs.shape[0]))) + refs.key = "subjectId" + + log.info("::: Indexing reads...") + reads = dt.Frame(list(set(reads))) + reads.names = ["queryId"] + reads["qidx"] = dt.Frame([(i + refs.shape[0]) for i in range(reads.shape[0])]) + reads.key = "queryId" + + log.info("::: Combining data...") + data = data[:, :, dt.join(reads)] + data = data[:, :, dt.join(refs)] + + del data["queryId"] + del data["subjectId"] + n_alns_0 = data.shape[0] + n_reads_0 = reads.shape[0] + n_refs_0 = refs.shape[0] + log.info( + f"::: References: {n_refs_0:,} | Reads: {n_reads_0:,} | Alignments: {n_alns_0:,}" + ) + + log.info("::: Allocating data structures...") + mat = data[:, [dt.f.qidx, dt.f.sidx, dt.f.var, dt.f.slen]].to_numpy() + data = None + + # Create a zeros array with the same number of rows as 'mat' + zeros_array = np.zeros((mat.shape[0], 5)) + + # Stack the zeros_array with the original 'mat' + m = np.column_stack([mat, zeros_array]) + zeros_array = None + + dtype = np.dtype( + [ + ("source", "int"), + ("subject", "int"), + 
("var", "float"), + ("slen", "int"), + ("s_W", "float"), + ("prob", "float"), + ("iter", "int"), + ("n_aln", "int"), + ("max_prob", "float"), + ] + ) + + # Convert the unstructured array to structured array + m = rf.unstructured_to_structured(m, dtype) + gc.collect() + log.info("::: Initializing data structures...") + init_data = initialize_subject_weights(m) + if reassign_iters > 0: + log.info(f"::: Reassigning reads with {reassign_iters} iterations") + else: + log.info("::: Reassigning reads until convergence") + no_multimaps = resolve_multimaps( + init_data, iters=reassign_iters, scale=reassign_scale + ) + + n_reads = len(list(set(no_multimaps["source"]))) + n_refs = len(list(set(no_multimaps["subject"]))) + n_alns = no_multimaps.shape[0] + log.info( + f"::: References: {n_refs:,} | Reads: {n_reads:,} | Alignments: {n_alns:,}" + ) + log.info( + f'::: Unique mapping reads: {no_multimaps[no_multimaps["n_aln"] == 1].shape[0]:,} | Multimapping reads: {no_multimaps[no_multimaps["n_aln"] > 1].shape[0]:,}' + ) + + # add this to the array + # no_multimaps["n_aln"] = subject_counts_array[no_multimaps["subject"]] + + # log.info(f"::: Removing references with less than {min_read_count} reads...") + # no_multimaps = no_multimaps[no_multimaps["n_aln"] >= min_read_count] + # log.info(f"{no_multimaps.shape[0]:,} alignments left") + log.info("::: Mapping back indices...") + if threads > 1: + dt.options.nthreads = threads - 1 + else: + dt.options.nthreads = 1 + + g = dt.Frame(no_multimaps["source"]) + g.names = ["qidx"] + reads.key = "qidx" + q = g[:, :, dt.join(reads)] + + g = dt.Frame(no_multimaps["subject"]) + g.names = ["sidx"] + refs.key = "sidx" + s = g[:, :, dt.join(refs)] + + log.info("::: Calculating reads per subject...") + # count how many alignments are in each subjectId + s_c = s[:, dt.count(dt.f.subjectId), dt.by(dt.f.subjectId)] + s_c.names = ["subjectId", "counts"] + references_m = dict() + log.info(f"::: Removing references with less than {min_read_count:,}...") + for i, k in zip(s_c[:, "subjectId"].to_list()[0], s_c[:, "counts"].to_list()[0]): + if k >= min_read_count: + references_m[i] = k + log.info(f"::: ::: Keeping {len(references_m):,} references") + s_c = None + # convert columns queryId from q and subjectId from s to a tuple + log.info("::: Creating filtered set...") + entries = defaultdict(set) + q_query_ids = q[:, "queryId"].to_list()[0] + s_subject_ids = s[:, "subjectId"].to_list()[0] + + for query_id, subject_id in zip(q_query_ids, s_subject_ids): + if subject_id in references_m: + entries[subject_id].add((query_id, subject_id)) + no_multimaps = None + q = None + s = None + q_query_ids = None + s_subject_ids = None + gc.collect() + log.info("::: Writing to BAM file...") + write_reassigned_bam( + bam=bam, + ref_counts=references_m, + out_files=out_files, + threads=threads, + entries=entries, + sort_memory=sort_memory, + min_read_ani=min_read_ani, + ) + + +def reassign(args): + logging.basicConfig( + level=logging.DEBUG, + format="%(levelname)s ::: %(asctime)s ::: %(message)s", + datefmt="%H:%M:%S", + ) + + args = get_arguments() + bam = args.bam + + tmp_dir = check_tmp_dir_exists(args.tmp_dir) + log.info("Temporary directory: %s", tmp_dir.name) + + out_files = create_output_files( + prefix=args.prefix, + bam=args.bam, + tmp_dir=tmp_dir, + # bam_reassigned=args.bam_reassigned, + ) + + logging.getLogger("my_logger").setLevel( + logging.DEBUG if args.debug else logging.INFO + ) + + if args.debug: + warnings.showwarning = handle_warning + else: + warnings.filterwarnings("ignore") 
+ logging.info("Resolving multi-mapping reads...") + reassign_reads( + bam=bam, + threads=args.threads, + reference_lengths=args.reference_lengths, + min_read_count=args.min_read_count, + min_read_ani=args.min_read_ani, + min_read_length=args.min_read_length, + reassign_iters=args.reassign_iters, + reassign_scale=args.reassign_scale, + sort_memory=args.sort_memory, + sort_by_name=args.sort_by_name, + out_files=out_files, + ) + log.info("Done!") diff --git a/bam_filter/sam_utils.py b/bam_filter/sam_utils.py index 222d708..175ecc3 100644 --- a/bam_filter/sam_utils.py +++ b/bam_filter/sam_utils.py @@ -3,7 +3,7 @@ import os import sys import pandas as pd -from multiprocessing import Pool +from multiprocessing import Pool, Manager, cpu_count import functools from scipy import stats import tqdm @@ -15,11 +15,15 @@ initializer, create_empty_output_files, create_empty_bam, + sort_keys_by_approx_weight, ) +import random from bam_filter.entropy import entropy, norm_entropy, gini_coeff, norm_gini_coeff from collections import defaultdict import pyranges as pr from pathlib import Path +import concurrent.futures +import gc # import cProfile as profile # import pstats @@ -249,6 +253,9 @@ def get_bam_stats( # prof.enable() bam, references = params results = [] + + if threads > 4: + threads = 4 samfile = pysam.AlignmentFile(bam, "rb", threads=threads) read_hits = defaultdict(int) @@ -339,8 +346,8 @@ def get_bam_stats( mean_coverage_trunc, mean_coverage_trunc_len = get_tad( cov_np, - trim_min=10, - trim_max=90, + trim_min=trim_min, + trim_max=trim_max, ) cov_pos = cov_np[cov_np > 0] @@ -806,7 +813,6 @@ def process_bam( scale=1e6, plot=False, plots_dir="coverage-plots", - chunksize=None, read_length_freqs=False, output_files=None, low_memory=False, @@ -839,14 +845,10 @@ def process_bam( "The BAM file contains references not found in the reference lengths file" ) sys.exit(1) - sys.exit(1) total_refs = samfile.nreferences logging.info(f"Found {total_refs:,} reference sequences") # logging.info(f"Found {samfile.mapped:,} alignments") - logging.info( - f"Removing references without mappings or less than {min_read_count} reads..." 
- ) # Remove references without mapped reads # alns_in_ref = defaultdict(int) # for aln in tqdm.tqdm( @@ -862,11 +864,28 @@ if not os.path.exists(plots_dir): os.makedirs(plots_dir) - references = [ - chrom.contig - for chrom in samfile.get_index_statistics() - if chrom.mapped >= min_read_count - ] + # references = [ + # chrom.contig + # for chrom in samfile.get_index_statistics() + # if chrom.mapped >= min_read_count + # ] + + logging.info(f"Removing references with less than {min_read_count} reads...") + + references_m = { + chrom.contig: chrom.mapped + for chrom in samfile.get_index_statistics() + if chrom.mapped >= min_read_count + } + else: + logging.info(f"Removing references with less than {min_read_count} reads...") + + references_m = { + chrom.contig: chrom.mapped + for chrom in samfile.get_index_statistics() + if chrom.mapped >= min_read_count + } + references = list(references_m.keys()) if len(references) == 0: logging.warning("No reference sequences with alignments found in the BAM file") @@ -887,18 +906,17 @@ ) samfile = pysam.AlignmentFile(bam, "rb", threads=threads) - if (chunksize is not None) and ((len(references) // chunksize) > threads): - c_size = chunksize - else: - c_size = calc_chunksize( - n_workers=threads, len_iterable=len(references), factor=4 - ) - ref_chunks = [references[i : i + c_size] for i in range(0, len(references), c_size)] + log.info("::: Creating reference chunks with uniform read amounts...") + + # split the references into chunks carrying a roughly uniform number of reads + ref_chunks = sort_keys_by_approx_weight( + input_dict=references_m, scale=1.5, num_cores=threads + ) + log.info(f"::: Created {len(ref_chunks):,} chunks") + ref_chunks = random.sample(ref_chunks, len(ref_chunks)) params = zip([bam] * len(ref_chunks), ref_chunks) try: - logging.info( - f"Processing {len(ref_chunks):,} chunks of {c_size:,} references each" - ) + logging.info(f"Processing {len(ref_chunks):,} chunks") if is_debug(): data = list( map( @@ -921,8 +939,8 @@ else: p = Pool( threads, - initializer=initializer, - initargs=([params, ref_lengths, scale],), + # initializer=initializer, + # initargs=([params, shared_dict],), ) data = list( @@ -962,6 +980,26 @@ return data +def write_to_file(alns, out_bam_file, header=None): + for aln in alns: + out_bam_file.write(pysam.AlignedSegment.fromstring(aln, header)) + + +def process_references_batch(references, bam, refs_idx, min_read_ani, threads=1): + alns = [] + with pysam.AlignmentFile(bam, "rb", threads=threads) as samfile: + for reference in references: + for aln in samfile.fetch( + reference=reference, multiple_iterators=False, until_eof=True + ): + ani_read = (1 - ((aln.get_tag("NM") / aln.infer_query_length()))) * 100 + if ani_read >= min_read_ani: + aln.reference_id = refs_idx[aln.reference_name] + alns.append(aln.to_string()) + + return alns + + def filter_reference_BAM( bam, df, @@ -1060,15 +1098,27 @@ # prof.enable() logging.info("Saving filtered stats...") df_filtered.to_csv( - out_files["stats_filtered"], sep="\t", index=False, compression="gzip" + out_files["stats_filtered"], + sep="\t", + index=False, ) if out_files["bam_filtered"] is not None: logging.info("Writing filtered BAM file...
(be patient)") - refs_dict = dict( - zip(df_filtered["reference"], df_filtered["bam_reference_length"]) - ) - (ref_names, ref_lengths) = zip(*refs_dict.items()) + # refs_dict = dict( + # zip(df_filtered["reference"], df_filtered["bam_reference_length"]) + # ) + references = df_filtered["reference"].tolist() + samfile = pysam.AlignmentFile(bam, "rb", threads=threads) + refs_dict = {x: samfile.get_reference_length(x) for x in references} + header = samfile.header + (ref_names, ref_lengths) = zip(*refs_dict.items()) + ref_dict_m = { + chrom.contig: chrom.mapped + for chrom in samfile.get_index_statistics() + if chrom.contig in ref_names + } + samfile.close() refs_idx = {sys.intern(str(x)): i for i, x in enumerate(ref_names)} if threads > 4: write_threads = 4 @@ -1083,28 +1133,84 @@ def filter_reference_BAM( threads=write_threads, ) - samfile = pysam.AlignmentFile(bam, "rb", threads=threads) - references = [x for x in samfile.references if x in refs_idx.keys()] - - logging.info( - f"::: Filtering {len(references):,} references sequentially..." + # logging.info( + # f"::: Filtering {len(references):,} references sequentially..." + # ) + # for reference in tqdm.tqdm( + # references, + # total=len(references), + # leave=False, + # ncols=80, + # desc="References processed", + # ): + # for aln in samfile.fetch( + # reference=reference, multiple_iterators=False, until_eof=True + # ): + # ani_read = ( + # 1 - ((aln.get_tag("NM") / aln.infer_query_length())) + # ) * 100 + # if ani_read >= min_read_ani: + # aln.reference_id = refs_idx[aln.reference_name] + # out_bam_file.write(aln) + num_cores = min(threads, cpu_count()) + # batch_size = len(references) // num_cores + 1 # Ensure non-zero batch size + log.info("::: Creating reference chunks with uniform read amounts...") + ref_chunks = sort_keys_by_approx_weight( + input_dict=ref_dict_m, scale=1.5, num_cores=threads, verbose=False ) - for reference in tqdm.tqdm( - references, - total=len(references), - leave=False, - ncols=80, - desc="References processed", - ): - for aln in samfile.fetch( - reference=reference, multiple_iterators=False, until_eof=True + num_cores = min(num_cores, len(ref_chunks)) + log.info(f"::: Using {num_cores} cores to write {len(ref_chunks)} chunk(s)") + + # with Manager() as manager: + # # Use Manager to create a read-only proxy for the dictionary + # #refs_idx = manager.dict(refs_idx) + + with concurrent.futures.ProcessPoolExecutor( + max_workers=num_cores + ) as executor: + # Use ProcessPoolExecutor to parallelize the processing of references in batches + futures = [] + for batch_references in tqdm.tqdm( + ref_chunks, + total=len(ref_chunks), + desc="Submitted batches", + unit="batch", + ncols=80, + leave=False, + disable=is_debug(), ): - ani_read = ( - 1 - ((aln.get_tag("NM") / aln.infer_query_length())) - ) * 100 - if ani_read >= min_read_ani: - aln.reference_id = refs_idx[aln.reference_name] - out_bam_file.write(aln) + future = executor.submit( + process_references_batch, + batch_references, + bam, + refs_idx, + min_read_ani=min_read_ani, + ) + futures.append(future) # Store the future + + # Use a while loop to continuously check for completed futures + log.info("::: Collecting batches...") + + completion_progress_bar = tqdm.tqdm( + total=len(futures), + desc="Completed", + unit="batch", + disable=is_debug(), + ncols=80, + leave=False, + ) + completed_count = 0 + + # Use as_completed to iterate over completed futures as they become available + for completed_future in concurrent.futures.as_completed(futures): + alns = 
completed_future.result() + write_to_file(alns=alns, out_bam_file=out_bam_file, header=header) + # Update the progress bar for each completed write + completion_progress_bar.update(1) + completed_count += 1 + completed_future.cancel() # Cancel the future to free memory + gc.collect() # Force garbage collection + completion_progress_bar.close() out_bam_file.close() # prof.disable() # # print profiling output @@ -1124,6 +1230,7 @@ def filter_reference_BAM( out_files["bam_filtered_tmp"], ) else: + logging.info("Sorting BAM file by coordinates...") pysam.sort( "-@", str(threads), diff --git a/bam_filter/utils.py b/bam_filter/utils.py index ea7cde7..397cde9 100644 --- a/bam_filter/utils.py +++ b/bam_filter/utils.py @@ -16,20 +16,153 @@ import numpy as np from pathlib import Path import pysam +import math +import tempfile + log = logging.getLogger("my_logger") log.setLevel(logging.INFO) timestr = time.strftime("%Y%m%d-%H%M%S") +def handle_warning(message, category, filename, lineno, file=None, line=None): + print("A warning occurred:") + print(message) + print("Do you wish to continue?") + + while True: + response = input("y/n: ").lower() + if response not in {"y", "n"}: + print("Not understood.") + else: + break + + if response == "n": + raise category(message) + + +# Check if the temporary directory exists, if not, create it +def check_tmp_dir_exists(tmpdir): + if tmpdir is None: + tmpdir = tempfile.TemporaryDirectory(dir=os.getcwd()) + else: + if not os.path.exists(tmpdir): + log.error(f"Temporary directory {tmpdir} does not exist") + exit(1) + tmpdir = tempfile.TemporaryDirectory(dir=os.path.abspath(tmpdir)) + return tmpdir + + def is_debug(): return logging.getLogger("my_logger").getEffectiveLevel() == logging.DEBUG +def sort_keys_by_approx_weight( + input_dict, scale=1, num_cores=1, refinement_steps=10, verbose=False +): + # Check for division by zero + if scale == 0: + raise ValueError("Scale cannot be zero.") + + # Calculate the target weight for each chunk + target_weight = int(scale * max(input_dict.values())) + + # Sort keys by their weights in descending order + sorted_keys = sorted(input_dict, key=lambda k: input_dict[k], reverse=True) + + # Calculate total weight + total_weight = sum(input_dict.values()) + + # Calculate the number of chunks based on the target weight + num_chunks = max(1, math.ceil(total_weight / target_weight)) + + # Initialize chunks with their total weights + chunks = [[] for _ in range(num_chunks)] + total_weights = [0] * num_chunks + + # Create a progress bar + progress_bar = tqdm.tqdm( + total=len(sorted_keys), + desc="Distributing keys", + unit="k", + unit_scale=True, + unit_divisor=1000, + disable=False, # Replace with your logic or a boolean value + leave=False, + ncols=80, + ) + + # Distribute keys into chunks + for key in sorted_keys: + min_chunk = min(range(num_chunks), key=lambda i: total_weights[i]) + chunks[min_chunk].append(key) + total_weights[min_chunk] += input_dict[key] + + # Update the progress bar + progress_bar.update(1) + + # Close the progress bar + progress_bar.close() + + # Remove empty chunks + chunks = [chunk for chunk in chunks if chunk] + + # Ensure a balanced number of reads in each chunk + max_reads = max(len(chunk) for chunk in chunks) + for chunk in chunks: + while len(chunk) < max_reads: + # Add elements from other chunks if needed + other_chunk = next((c for c in chunks if c != chunk and c), None) + if other_chunk: + element = other_chunk.pop(0) + chunk.append(element) + + # Initial balance + initial_balance = max(total_weights) 
- min(total_weights) + + # Refinement step + for _ in range(refinement_steps): + # Sort chunks by their total weights in ascending order + sorted_chunks = sorted(range(num_chunks), key=lambda i: total_weights[i]) + + # Move keys between chunks to minimize the difference in total weights + for src_chunk in sorted_chunks[:-1]: + dest_chunk = sorted_chunks[-1] + if total_weights[dest_chunk] - total_weights[src_chunk] > 1: + # Move one key from dest_chunk to src_chunk + key_to_move = chunks[dest_chunk].pop(0) + chunks[src_chunk].append(key_to_move) + total_weights[src_chunk] += input_dict[key_to_move] + total_weights[dest_chunk] -= input_dict[key_to_move] + + # Check for improvement in balance + current_balance = max(total_weights) - min(total_weights) + if current_balance >= initial_balance: + break # No improvement, exit the loop + + # Update initial balance for the next iteration + initial_balance = current_balance + + chunks = [chunk for chunk in chunks if chunk] + + # Print the min, max, and average weight of each chunk + if verbose: + for i, chunk in enumerate(chunks, 1): + chunk_weights = [input_dict[key] for key in chunk] + min_weight = min(chunk_weights) + max_weight = max(chunk_weights) + avg_weight = sum(chunk_weights) / len(chunk_weights) + print( + f"Chunk {i}: Total = {sum(chunk_weights)}, Min Weight = {min_weight}, Max Weight = {max_weight}, Average Weight = {avg_weight}" + ) + + return chunks + + def create_empty_output_files(out_files): for key, value in out_files.items(): if value is not None: - if key == "bam_filtered": + if key == "bam_filtered" or key == "bam_reassigned": create_empty_bam(value) elif ( key == "bam_filtered_tmp" or key == "bam_tmp" or key == "bam_tmp_sorted" @@ -162,6 +295,78 @@ def is_valid_file(parser, arg, var): return arg +# From https://stackoverflow.com/a/59617044/15704171 +def convert_list_to_str(lst): + n = len(lst) + if not n: + return "" + if n == 1: + return lst[0] + return ", ".join(lst[:-1]) + f" or {lst[-1]}" + + +lca_ranks = [ + "superkingdom", + "domain", + "lineage", + "kingdom", + "subkingdom", + "superphylum", + "phylum", + "subphylum", + "superclass", + "class", + "subclass", + "infraclass", + "clade", + "cohort", + "subcohort", + "superorder", + "order", + "suborder", + "infraorder", + "parvorder", + "superfamily", + "family", + "subfamily", + "tribe", + "subtribe", + "infratribe", + "genus", + "subgenus", + "section", + "series", + "subseries", + "subsection", + "species", + "species group", + "species subgroup", + "subspecies", + "varietas", + "morph", + "subvariety", + "forma", + "forma specialis", + "biotype", + "genotype", + "isolate", + "pathogroup", + "serogroup", + "serotype", + "strain", +] + + +def check_lca_ranks(val, parser, var): + value = str(val) + if value in lca_ranks: + return value + else: + parser.error( + f"argument {var}: Invalid value {value}. 
Filter has to be one of {convert_list_to_str(lca_ranks)}" + ) + + defaults = { "min_read_length": 30, "min_read_count": 3, @@ -183,10 +388,16 @@ is_valid_file(parser, arg, var): "stats": None, "stats_filtered": None, "bam_filtered": None, + "bam_reassigned": None, "knee_plot": None, "read_length_freqs": None, "read_hits_count": None, "tmp_dir": None, + "reassign_iters": 25, + "reassign_scale": 0.9, + "reassign_target": "mapq", + "rank_lca": "species", + "lca_summary": None, } help_msg = { @@ -215,6 +426,7 @@ is_valid_file(parser, arg, var): "stats": "Save a TSV file with the statistics for each reference", "stats_filtered": "Save a TSV file with the statistics for each reference after filtering", "bam_filtered": "Save a BAM file with the references that passed the filtering criteria", + "bam_reassigned": "Save a BAM file without multimapping reads", "coverage_plots": "Folder where to save genome coverage plots", "knee_plot": "Plot knee plot", "sort_by_name": "Sort by read names", @@ -225,6 +437,17 @@ is_valid_file(parser, arg, var): "debug": "Print debug messages", "reference_lengths": "File with references lengths", "low_memory": "Activate the low memory mode", + "reassign": "Run an EM algorithm to reassign reads to references", + "reassign_method": "Method for the EM algorithm", + "reassign_iters": "Number of iterations for the EM algorithm", + "reassign_scale": "Scale to select the best-weighted alignments", + "reassign_target": "Which target to use for the EM algorithm. Only mapq or pident are supported.", + "lca": "Calculate LCA for each read and estimate abundances", + "names": "Names dmp file from taxonomy", + "nodes": "Nodes dmp file from taxonomy", + "acc2taxid": "acc2taxid file from taxonomy", + "rank_lca": "Rank to use for LCA calculation", + "lca_summary": "Save a TSV file with the LCA summary", "version": "Print program version", } @@ -234,18 +457,53 @@ def get_arguments(argv=None): description="A simple tool to calculate metrics from a BAM file and filter with uneven coverage.", formatter_class=argparse.ArgumentDefaultsHelpFormatter, ) - # add subparser for filtering options: - filter_args = parser.add_argument_group("filtering arguments") - misc_args = parser.add_argument_group("miscellaneous arguments") - out_args = parser.add_argument_group("output arguments") parser.add_argument( + "--version", + action="version", + version="%(prog)s " + __version__, + help=help_msg["version"], + ) + parser.add_argument( + "--debug", dest="debug", action="store_true", help=help_msg["debug"] + ) + + sub_parsers = parser.add_subparsers( + help="positional arguments", + dest="action", + ) + + # Create parent subparser.
Note `add_help=False` and creation via `argparse.` + parent_parser = argparse.ArgumentParser(add_help=False) + + required = parent_parser.add_argument_group("required arguments") + required.add_argument( "--bam", required=True, dest="bam", type=lambda x: is_valid_file(parser, x, "bam"), help=help_msg["bam"], ) - parser.add_argument( + # required = parent_parser.add_argument_group("required arguments") + optional = parent_parser.add_argument_group("optional arguments") + optional.add_argument( + "-p", + "--prefix", + type=str, + default=defaults["prefix"], + metavar="STR", + dest="prefix", + help=help_msg["prefix"], + ) + optional.add_argument( + "-r", + "--reference-lengths", + type=lambda x: is_valid_file(parser, x, "reference_lengths"), + metavar="FILE", + default=defaults["reference_lengths"], + dest="reference_lengths", + help=help_msg["reference_lengths"], + ) + optional.add_argument( "-t", "--threads", type=lambda x: int( @@ -256,7 +514,155 @@ def get_arguments(argv=None): default=1, help=help_msg["threads"], ) - misc_args.add_argument( + # Create the parser sub-command for db creation + parser_reassign = sub_parsers.add_parser( + "reassign", + help="Reassign reads to references using an EM algorithm", + parents=[parent_parser], + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + # create the parser sub-commands + parser_filter = sub_parsers.add_parser( + "filter", + help="Filter references based on coverage and other metrics", + parents=[parent_parser], + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser_lca = sub_parsers.add_parser( + "lca", + help="Calculate LCA for each read and estimate abundances at each rank", + parents=[parent_parser], + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + + # createdb_required_args = parser_createdb.add_argument_group("required arguments") + # reassign_required_args = parser_reassign.add_argument_group( + # "Re-assign required arguments" + # ) + reassign_optional_args = parser_reassign.add_argument_group( + "Re-assign optional arguments" + ) + + filter_required_args = parser_filter.add_argument_group("Filter required arguments") + # filter_optional_args = parser_filter.add_argument_group("Filter optional arguments") + + # lca_required_args = parser_lca.add_argument_group("LCA required arguments") + lca_optional_args = parser_lca.add_argument_group("LCA optional arguments") + + # add subparser for filtering options: + # reassign_args = parser.add_argument_group("reassign arguments") + filtering_filt_args = parser_filter.add_argument_group("filtering arguments") + # lca_args = parser.add_argument_group("lca arguments") + misc_filter_args = parser_filter.add_argument_group("miscellaneous arguments") + out_filter_args = parser_filter.add_argument_group("output arguments") + # parser.add_argument( + # "--bam", + # required=True, + # dest="bam", + # type=lambda x: is_valid_file(parser, x, "bam"), + # help=help_msg["bam"], + # ) + # parser.add_argument( + # "-t", + # "--threads", + # type=lambda x: int( + # check_values(x, minval=1, maxval=1000, parser=parser, var="--threads") + # ), + # dest="threads", + # metavar="INT", + # default=1, + # help=help_msg["threads"], + # ) + + reassign_optional_args.add_argument( + "-i", + "--iters", + type=lambda x: int( + check_values( + x, minval=0, maxval=100000, parser=parser, var="--reassign-n-iters" + ) + ), + metavar="INT", + default=defaults["reassign_iters"], + dest="reassign_iters", + help=help_msg["reassign_iters"], + ) + reassign_optional_args.add_argument( + 
"-s", + "--scale", + type=lambda x: float( + check_values(x, minval=0, maxval=1, parser=parser, var="--scale") + ), + metavar="FLOAT", + default=defaults["reassign_scale"], + dest="reassign_scale", + help=help_msg["reassign_scale"], + ) + reassign_optional_args.add_argument( + "-A", + "--min-read-ani", + type=lambda x: float( + check_values(x, minval=0, maxval=100, parser=parser, var="--min-read-ani") + ), + metavar="FLOAT", + default=defaults["min_read_ani"], + dest="min_read_ani", + help=help_msg["min_read_ani"], + ) + reassign_optional_args.add_argument( + "-l", + "--min-read-length", + type=lambda x: int( + check_values( + x, minval=1, maxval=100000, parser=parser, var="--min-read-length" + ) + ), + default=defaults["min_read_length"], + metavar="INT", + dest="min_read_length", + help=help_msg["min_read_length"], + ) + reassign_optional_args.add_argument( + "-n", + "--min-read-count", + type=lambda x: int( + check_values( + x, minval=1, maxval=np.Inf, parser=parser, var="--min-read-count" + ) + ), + default=defaults["min_read_count"], + metavar="INT", + dest="min_read_count", + help=help_msg["min_read_count"], + ) + reassign_optional_args.add_argument( + "-o", + "--out-bam", + dest="bam_reassigned", + default=defaults["bam_reassigned"], + metavar="FILE", + type=str, + nargs="?", + const="", + help=help_msg["bam_reassigned"], + ) + reassign_optional_args.add_argument( + "-m", + "--sort-memory", + type=lambda x: check_suffix(x, parser=parser, var="--sort-memory"), + default=defaults["sort_memory"], + metavar="STR", + dest="sort_memory", + help=help_msg["sort_memory"], + ) + reassign_optional_args.add_argument( + "-N", + "--sort-by-name", + dest="sort_by_name", + action="store_true", + help=help_msg["sort_by_name"], + ) + misc_filter_args.add_argument( "--reference-trim-length", type=lambda x: int( check_values( @@ -268,7 +674,7 @@ def get_arguments(argv=None): default=0, help=help_msg["trim_ends"], ) - misc_args.add_argument( + misc_filter_args.add_argument( "--trim-min", type=lambda x: int( check_values(x, minval=0, maxval=100, parser=parser, var="--trim-min") @@ -278,7 +684,7 @@ def get_arguments(argv=None): default=10, help=help_msg["trim_min"], ) - misc_args.add_argument( + misc_filter_args.add_argument( "--trim-max", type=lambda x: int( check_values(x, minval=0, maxval=100, parser=parser, var="--trim-max") @@ -288,16 +694,7 @@ def get_arguments(argv=None): default=90, help=help_msg["trim_max"], ) - parser.add_argument( - "-p", - "--prefix", - type=str, - default=defaults["prefix"], - metavar="STR", - dest="prefix", - help=help_msg["prefix"], - ) - filter_args.add_argument( + filtering_filt_args.add_argument( "-A", "--min-read-ani", type=lambda x: float( @@ -308,7 +705,7 @@ def get_arguments(argv=None): dest="min_read_ani", help=help_msg["min_read_ani"], ) - filter_args.add_argument( + filtering_filt_args.add_argument( "-l", "--min-read-length", type=lambda x: int( @@ -321,7 +718,7 @@ def get_arguments(argv=None): dest="min_read_length", help=help_msg["min_read_length"], ) - filter_args.add_argument( + filtering_filt_args.add_argument( "-n", "--min-read-count", type=lambda x: int( @@ -334,7 +731,7 @@ def get_arguments(argv=None): dest="min_read_count", help=help_msg["min_read_count"], ) - filter_args.add_argument( + filtering_filt_args.add_argument( "-b", "--min-expected-breadth-ratio", type=lambda x: float( @@ -347,7 +744,7 @@ def get_arguments(argv=None): dest="min_expected_breadth_ratio", help=help_msg["min_expected_breadth_ratio"], ) - filter_args.add_argument( + 
filtering_filt_args.add_argument( "-e", "--min-normalized-entropy", type=lambda x: check_values_auto( @@ -358,7 +755,7 @@ def get_arguments(argv=None): dest="min_norm_entropy", help=help_msg["min_norm_entropy"], ) - filter_args.add_argument( + filtering_filt_args.add_argument( "-g", "--min-normalized-gini", type=lambda x: check_values_auto( @@ -369,7 +766,7 @@ def get_arguments(argv=None): dest="min_norm_gini", help=help_msg["min_norm_gini"], ) - filter_args.add_argument( + filtering_filt_args.add_argument( "-B", "--min-breadth", type=lambda x: float( @@ -380,7 +777,7 @@ def get_arguments(argv=None): dest="min_breadth", help=help_msg["min_breadth"], ) - filter_args.add_argument( + filtering_filt_args.add_argument( "-a", "--min-avg-read-ani", type=lambda x: float( @@ -393,7 +790,7 @@ def get_arguments(argv=None): dest="min_avg_read_ani", help=help_msg["min_avg_read_ani"], ) - filter_args.add_argument( + filtering_filt_args.add_argument( "-c", "--min-coverage-evenness", type=lambda x: float( @@ -406,7 +803,7 @@ def get_arguments(argv=None): dest="min_coverage_evenness", help=help_msg["min_coverage_evenness"], ) - filter_args.add_argument( + filtering_filt_args.add_argument( "-V", "--min-coeff-var", type=lambda x: float( @@ -419,7 +816,7 @@ def get_arguments(argv=None): dest="min_coeff_var", help=help_msg["min_coeff_var"], ) - filter_args.add_argument( + filtering_filt_args.add_argument( "-C", "--min-coverage-mean", type=lambda x: float( @@ -432,13 +829,13 @@ def get_arguments(argv=None): dest="min_coverage_mean", help=help_msg["min_coverage_mean"], ) - filter_args.add_argument( + filtering_filt_args.add_argument( "--include-low-detection", dest="transform_cov_evenness", action="store_true", help=help_msg["transform_cov_evenness"], ) - parser.add_argument( + misc_filter_args.add_argument( "-m", "--sort-memory", type=lambda x: check_suffix(x, parser=parser, var="--sort-memory"), @@ -447,20 +844,20 @@ def get_arguments(argv=None): dest="sort_memory", help=help_msg["sort_memory"], ) - parser.add_argument( + misc_filter_args.add_argument( "-N", "--sort-by-name", dest="sort_by_name", action="store_true", help=help_msg["sort_by_name"], ) - parser.add_argument( + misc_filter_args.add_argument( "--disable-sort", dest="disable_sort", action="store_true", help=help_msg["disable_sort"], ) - parser.add_argument( + misc_filter_args.add_argument( "--scale", type=lambda x: check_suffix(x, parser=parser, var="--scale"), default=defaults["scale"], @@ -469,16 +866,16 @@ def get_arguments(argv=None): help=help_msg["scale"], ) # reference_lengths - parser.add_argument( - "-r", - "--reference-lengths", - type=lambda x: is_valid_file(parser, x, "reference_lengths"), - metavar="FILE", - default=defaults["reference_lengths"], - dest="reference_lengths", - help=help_msg["reference_lengths"], - ) - out_args.add_argument( + # filter_optional_args.add_argument( + # "-r", + # "--reference-lengths", + # type=lambda x: is_valid_file(parser, x, "reference_lengths"), + # metavar="FILE", + # default=defaults["reference_lengths"], + # dest="reference_lengths", + # help=help_msg["reference_lengths"], + # ) + filter_required_args.add_argument( "--stats", dest="stats", default=defaults["stats"], @@ -489,7 +886,7 @@ def get_arguments(argv=None): required=True, help=help_msg["stats"], ) - out_args.add_argument( + out_filter_args.add_argument( "--stats-filtered", dest="stats_filtered", default=defaults["stats_filtered"], @@ -499,7 +896,7 @@ def get_arguments(argv=None): const="", help=help_msg["stats_filtered"], ) - 
out_args.add_argument( + out_filter_args.add_argument( "--bam-filtered", dest="bam_filtered", default=defaults["bam_filtered"], @@ -509,7 +906,7 @@ def get_arguments(argv=None): const="", help=help_msg["bam_filtered"], ) - out_args.add_argument( + out_filter_args.add_argument( "--read-length-freqs", dest="read_length_freqs", default=defaults["read_length_freqs"], @@ -519,7 +916,7 @@ def get_arguments(argv=None): const="", help=help_msg["read_length_freqs"], ) - out_args.add_argument( + out_filter_args.add_argument( "--read-hits-count", dest="read_hits_count", default=defaults["read_hits_count"], @@ -529,7 +926,7 @@ def get_arguments(argv=None): const="", help=help_msg["read_hits_count"], ) - out_args.add_argument( + out_filter_args.add_argument( "--knee-plot", dest="knee_plot", default=defaults["knee_plot"], @@ -539,7 +936,7 @@ def get_arguments(argv=None): const="", help=help_msg["knee_plot"], ) - out_args.add_argument( + out_filter_args.add_argument( "--coverage-plots", dest="coverage_plots", metavar="FILE", @@ -549,17 +946,17 @@ def get_arguments(argv=None): const="", help=help_msg["coverage_plots"], ) - parser.add_argument( - "--chunk-size", - type=lambda x: int( - check_values(x, minval=1, maxval=100000, parser=parser, var="--chunk-size") - ), - default=defaults["chunk_size"], - metavar="INT", - dest="chunk_size", - help=help_msg["chunk_size"], - ) - parser.add_argument( + # parser.add_argument( + # "--chunk-size", + # type=lambda x: int( + # check_values(x, minval=1, maxval=100000, parser=parser, var="--chunk-size") + # ), + # default=defaults["chunk_size"], + # metavar="INT", + # dest="chunk_size", + # help=help_msg["chunk_size"], + # ) + misc_filter_args.add_argument( "--tmp-dir", type=str, default=defaults["tmp_dir"], @@ -567,20 +964,59 @@ def get_arguments(argv=None): dest="tmp_dir", help=help_msg["tmp_dir"], ) - parser.add_argument( + misc_filter_args.add_argument( "--low-memory", dest="low_memory", action="store_true", help=help_msg["low_memory"], ) - parser.add_argument( - "--debug", dest="debug", action="store_true", help=help_msg["debug"] + + lca_optional_args.add_argument( + "--names", + metavar="FILE", + type=lambda x: is_valid_file(parser, x, "names"), + dest="names", + help=help_msg["names"], ) - parser.add_argument( - "--version", - action="version", - version="%(prog)s " + __version__, - help=help_msg["version"], + lca_optional_args.add_argument( + "--nodes", + metavar="FILE", + type=lambda x: is_valid_file(parser, x, "nodes"), + dest="nodes", + help=help_msg["nodes"], + ) + lca_optional_args.add_argument( + "--acc2taxid", + metavar="FILE", + type=lambda x: is_valid_file(parser, x, "acc2taxid"), + dest="acc2taxid", + help=help_msg["acc2taxid"], + ) + lca_optional_args.add_argument( + "--rank-lca", + metavar="STR", + type=lambda x: str(check_lca_ranks(x, parser=parser, var="--rank-lca")), + default=defaults["rank_lca"], + dest="rank_lca", + help=help_msg["rank_lca"], + ) + lca_optional_args.add_argument( + "--lca-summary", + dest="lca_summary", + metavar="FILE", + default=defaults["lca_summary"], + type=str, + nargs="?", + const="", + help=help_msg["lca_summary"], + ) + lca_optional_args.add_argument( + "--scale", + type=lambda x: check_suffix(x, parser=parser, var="--scale"), + default=defaults["scale"], + dest="scale", + metavar="STR", + help=help_msg["scale"], ) args = parser.parse_args(None if sys.argv[1:] else ["-h"]) return args @@ -711,46 +1147,59 @@ def calc_chunksize(n_workers, len_iterable, factor=4): # } # return out_files def create_output_files( - prefix, 
bam, - stats, - stats_filtered, - bam_filtered, - read_length_freqs, - read_hits_count, - knee_plot, - coverage_plots, tmp_dir, + prefix=None, + stats="", + stats_filtered="", + bam_reassigned="", + bam_filtered="", + read_length_freqs="", + read_hits_count="", + knee_plot="", + coverage_plots="", + lca_summary="", ): if prefix is None: prefix = Path(bam).with_suffix("").name - if stats == "": + if tmp_dir is not None: + tmp_dir = tmp_dir.name + + if stats == "" or stats is None: stats = f"{prefix}_stats.tsv.gz" - if stats_filtered == "": + if stats_filtered == "" or stats_filtered is None: stats_filtered = f"{prefix}_stats-filtered.tsv.gz" - if bam_filtered == "": + if bam_filtered == "" or bam_filtered is None: bam_filtered = f"{prefix}.filtered.bam" - if read_length_freqs == "": + if bam_reassigned == "" or bam_reassigned is None: + bam_reassigned = f"{prefix}.reassigned.bam" + if read_length_freqs == "" or read_length_freqs is None: read_length_freqs = f"{prefix}_read-length-freqs.json" - if read_hits_count == "": + if read_hits_count == "" or read_hits_count is None: read_hits_count = f"{prefix}_read-hits-count.tsv.gz" - if knee_plot == "": + if knee_plot == "" or knee_plot is None: knee_plot = f"{prefix}_knee-plot.png" - if coverage_plots == "": + if coverage_plots == "" or coverage_plots is None: coverage_plots = f"{prefix}_coverage-plots" + if lca_summary == "" or lca_summary is None: + lca_summary = f"{prefix}_lca-summary.tsv.gz" # create output files out_files = { "stats": stats, "stats_filtered": stats_filtered, - "bam_filtered_tmp": f"{tmp_dir.name}/{prefix}.filtered.tmp.bam", - "bam_tmp": f"{tmp_dir.name}/{prefix}.tmp.bam", - "bam_tmp_sorted": f"{tmp_dir.name}/{prefix}.tmp.sorted.bam", + "bam_filtered_tmp": f"{tmp_dir}/{prefix}.filtered.tmp.bam", + "bam_tmp": f"{tmp_dir}/{prefix}.tmp.bam", + "bam_tmp_sorted": f"{tmp_dir}/{prefix}.tmp.sorted.bam", "bam_filtered": bam_filtered, + "bam_reassigned_tmp": f"{tmp_dir}/{prefix}.reassigned.tmp.bam", + "bam_reassigned_sorted": f"{tmp_dir}/{prefix}.reassigned.sorted.bam", + "bam_reassigned": bam_reassigned, "read_length_freqs": read_length_freqs, "read_hits_count": read_hits_count, "knee_plot": knee_plot, "coverage_plot_dir": coverage_plots, + "lca_summary": lca_summary, } return out_files
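
Note on the chunking helper: `sort_keys_by_approx_weight` in `bam_filter/utils.py` greedily packs references into chunks whose summed read counts aim to stay near `scale * max(count)`, then runs a few refinement passes so worker processes receive comparable workloads. A small usage example; the contig names and read counts below are invented for illustration.

from bam_filter.utils import sort_keys_by_approx_weight

# Mapped-read counts per reference (values are made up for the example).
read_counts = {"contig_a": 9000, "contig_b": 150, "contig_c": 4800,
               "contig_d": 75, "contig_e": 3200}

chunks = sort_keys_by_approx_weight(input_dict=read_counts, scale=1.5, num_cores=4)
for i, chunk in enumerate(chunks, 1):
    print(f"chunk {i}: {chunk} -> {sum(read_counts[k] for k in chunk):,} reads")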
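Note on the parallel BAM writer: the new code in `filter_reference_BAM` fetches each reference chunk in a worker process, serialises passing alignments with `to_string()` (pysam objects cannot be pickled across processes), and writes them back in the parent with `pysam.AlignedSegment.fromstring`. A condensed sketch of that fan-out/fan-in pattern, with hypothetical function names, assuming `header` already describes only the kept references in `refs_idx` order; sorting and indexing of the result would still happen afterwards, as in the patch.

import concurrent.futures
import pysam

def fetch_chunk(bam, references, refs_idx, min_read_ani):
    # Worker process: collect passing alignments for one chunk of references as SAM strings.
    alns = []
    with pysam.AlignmentFile(bam, "rb") as samfile:
        for ref in references:
            for aln in samfile.fetch(reference=ref, multiple_iterators=False):
                ani = (1 - aln.get_tag("NM") / aln.infer_query_length()) * 100
                if ani >= min_read_ani:
                    aln.reference_id = refs_idx[aln.reference_name]  # re-index to the reduced header
                    alns.append(aln.to_string())
    return alns

def write_filtered(bam, ref_chunks, refs_idx, header, out_path, min_read_ani=90.0):
    # Parent process: submit one job per chunk and write results as they complete.
    with pysam.AlignmentFile(out_path, "wb", header=header) as out_bam:
        with concurrent.futures.ProcessPoolExecutor() as pool:
            futures = [
                pool.submit(fetch_chunk, bam, chunk, refs_idx, min_read_ani)
                for chunk in ref_chunks
            ]
            for fut in concurrent.futures.as_completed(futures):
                for aln in fut.result():
                    out_bam.write(pysam.AlignedSegment.fromstring(aln, out_bam.header))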
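Note on the CLI restructuring: the argparse changes replace the single flat parser with `reassign`, `filter` and `lca` subcommands that inherit shared options such as `--bam`, `--threads`, `--prefix` and `--reference-lengths` from a parent parser built with `add_help=False`. Reduced to a minimal, standalone illustration of the pattern (option set trimmed for brevity):

import argparse

parser = argparse.ArgumentParser(description="subcommand layout illustration")
parser.add_argument("--debug", action="store_true")
sub = parser.add_subparsers(dest="action")

# Shared options live on a parent parser; add_help=False avoids a clash with
# the automatic -h/--help that each subparser already defines.
parent = argparse.ArgumentParser(add_help=False)
parent.add_argument("--bam", required=True)
parent.add_argument("-t", "--threads", type=int, default=1)

for name in ("reassign", "filter", "lca"):
    sub.add_parser(name, parents=[parent])

args = parser.parse_args(["filter", "--bam", "sample.bam", "-t", "4"])
print(args.action, args.bam, args.threads)  # -> filter sample.bam 4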