
Commit

Split steps
genomewalker committed Nov 25, 2023
1 parent 2383d06 commit 195fd8d
Showing 6 changed files with 2,327 additions and 384 deletions.
259 changes: 12 additions & 247 deletions bam_filter/__main__.py
@@ -1,3 +1,4 @@
#
"""
This program is free software: you can redistribute it and/or modify it under the terms of the GNU
General Public License as published by the Free Software Foundation, either version 3 of the
@@ -11,271 +12,35 @@
see <https://www.gnu.org/licenses/>.
"""


import logging
import pandas as pd
from bam_filter.sam_utils import process_bam, filter_reference_BAM, check_bam_file

from bam_filter.reassign import reassign
from bam_filter.filter import filter_references
from bam_filter.lca import do_lca
from bam_filter.utils import (
    get_arguments,
    create_output_files,
    concat_df,
)
from bam_filter.entropy import find_knee
import json
import warnings
from collections import Counter
from functools import reduce
import os
import tempfile

log = logging.getLogger("my_logger")


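# Interactive warning hook: installed as warnings.showwarning when --debug is
# set, so each warning pauses execution and asks whether to continue.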
def handle_warning(message, category, filename, lineno, file=None, line=None):
    print("A warning occurred:")
    print(message)
    print("Do you wish to continue?")

    while True:
        response = input("y/n: ").lower()
        if response not in {"y", "n"}:
            print("Not understood.")
        else:
            break

    if response == "n":
        raise category(message)


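# JSON helper: lets json.dumps serialize arbitrary objects via their __dict__.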
def obj_dict(obj):
    return obj.__dict__


def get_summary(obj):
    return obj.to_summary()


def get_lens(obj):
    return obj.get_read_length_freqs()


# Check if the temporary directory exists; if not, create one
def check_tmp_dir_exists(tmpdir):
    if tmpdir is None:
        tmpdir = tempfile.TemporaryDirectory(dir=os.getcwd())
    else:
        if not os.path.exists(tmpdir):
            log.error(f"Temporary directory {tmpdir} does not exist")
            exit(1)
        tmpdir = tempfile.TemporaryDirectory(dir=os.path.abspath(tmpdir))
    return tmpdir


def main():
    logging.basicConfig(
        level=logging.DEBUG,
        format="%(levelname)s ::: %(asctime)s ::: %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
    )

    logging.getLogger("urllib3").setLevel(logging.WARNING)
    args = get_arguments()

    tmp_dir = check_tmp_dir_exists(args.tmp_dir)
    log.info("Temporary directory: %s", tmp_dir.name)
    if args.trim_min >= args.trim_max:
        log.error("trim_min must be less than trim_max")
        exit(1)

    logging.getLogger("my_logger").setLevel(
        logging.DEBUG if args.debug else logging.INFO
    )

    if args.debug:
        warnings.showwarning = handle_warning
    else:
        warnings.filterwarnings("ignore")

    plot_coverage = False
    if args.coverage_plots is not None:
        plot_coverage = True

    if args.stats_filtered is None and (args.bam_filtered is not None):
        logging.error(
            "You need to specify a filtered stats file to obtain the filtered BAM file"
        )
        exit(1)

    out_files = create_output_files(
        prefix=args.prefix,
        bam=args.bam,
        stats=args.stats,
        stats_filtered=args.stats_filtered,
        bam_filtered=args.bam_filtered,
        read_length_freqs=args.read_length_freqs,
        read_hits_count=args.read_hits_count,
        knee_plot=args.knee_plot,
        coverage_plots=args.coverage_plots,
        tmp_dir=tmp_dir,
    )

    bam = check_bam_file(
        bam=args.bam,
        threads=args.threads,
        reference_lengths=args.reference_lengths,
        sort_memory=args.sort_memory,
    )

    data = process_bam(
        bam=bam,
        threads=args.threads,
        reference_lengths=args.reference_lengths,
        min_read_count=args.min_read_count,
        min_read_ani=args.min_read_ani,
        trim_ends=args.trim_ends,
        trim_min=args.trim_min,
        trim_max=args.trim_max,
        scale=args.scale,
        plot=plot_coverage,
        plots_dir=out_files["coverage_plot_dir"],
        chunksize=args.chunk_size,
        read_length_freqs=args.read_length_freqs,
        output_files=out_files,
        sort_memory=args.sort_memory,
        low_memory=args.low_memory,
    )
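    # In low-memory mode process_bam leaves a sorted temporary BAM on disk;
    # downstream steps read from that file instead of the original input.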
    if args.low_memory:
        bam = out_files["bam_tmp_sorted"]

logging.info("Reducing results to a single dataframe")
# data = list(filter(None, data))
data_df = [x[0] for x in data if x[0] is not None]
data_df = concat_df(data_df)

if args.read_length_freqs is not None:
logging.info("Calculating read length frequencies...")
lens = [x[1] for x in data if x[1] is not None]
lens = json.dumps(lens, default=obj_dict, ensure_ascii=False, indent=4)
with open(out_files["read_length_freqs"], "w", encoding="utf-8") as outfile:
print(lens, file=outfile)

    if args.read_hits_count is not None:
        logging.info("Calculating read hits counts...")
        hits = [x[2] for x in data if x[2] is not None]

        # merge dicts and sum values
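        # (Counter.update adds counts in place and returns None, so "or x"
        # hands the accumulator to the next reduction step.)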
        hits = reduce(lambda x, y: x.update(y) or x, (Counter(dict(x)) for x in hits))
        # hits = sum(map(Counter, hits), Counter())

        # convert dict to dataframe
        hits = (
            pd.DataFrame.from_dict(hits, orient="index", columns=["count"])
            .rename_axis("read")
            .reset_index()
            .sort_values(by="count", ascending=False)
        )

        hits.to_csv(
            out_files["read_hits_count"], sep="\t", index=False, compression="gzip"
        )

logging.info(f"Writing reference statistics to {out_files['stats']}")
data_df.to_csv(out_files["stats"], sep="\t", index=False, compression="gzip")

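    # Build the filtering thresholds. Three cases: no entropy/gini thresholds
    # requested (basic filters only), "auto" (estimate both cutoffs from the
    # knee of the entropy/gini curve), or explicit user-supplied values.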
    if args.min_norm_entropy is None or args.min_norm_gini is None:
        filter_conditions = {
            "min_read_length": args.min_read_length,
            "min_read_count": args.min_read_count,
            "min_expected_breadth_ratio": args.min_expected_breadth_ratio,
            "min_breadth": args.min_breadth,
            "min_avg_read_ani": args.min_avg_read_ani,
            "min_coverage_evenness": args.min_coverage_evenness,
            "min_coeff_var": args.min_coeff_var,
            "min_coverage_mean": args.min_coverage_mean,
        }
    elif args.min_norm_entropy == "auto" or args.min_norm_gini == "auto":
        if data_df.shape[0] > 1:
            min_norm_gini, min_norm_entropy = find_knee(
                data_df, out_plot_name=out_files["knee_plot"]
            )

            if min_norm_gini is None or min_norm_entropy is None:
                logging.warning(
                    "Could not find knee in entropy plot. Disabling filtering by entropy/gini inequality."
                )
                filter_conditions = {
                    "min_read_length": args.min_read_length,
                    "min_read_count": args.min_read_count,
                    "min_expected_breadth_ratio": args.min_expected_breadth_ratio,
                    "min_breadth": args.min_breadth,
                    "min_avg_read_ani": args.min_avg_read_ani,
                    "min_coverage_evenness": args.min_coverage_evenness,
                    "min_coeff_var": args.min_coeff_var,
                    "min_coverage_mean": args.min_coverage_mean,
                }
            else:
                filter_conditions = {
                    "min_read_length": args.min_read_length,
                    "min_read_count": args.min_read_count,
                    "min_expected_breadth_ratio": args.min_expected_breadth_ratio,
                    "min_breadth": args.min_breadth,
                    "min_avg_read_ani": args.min_avg_read_ani,
                    "min_coverage_evenness": args.min_coverage_evenness,
                    "min_coeff_var": args.min_coeff_var,
                    "min_coverage_mean": args.min_coverage_mean,
                    "min_norm_entropy": min_norm_entropy,
                    "min_norm_gini": min_norm_gini,
                }
        else:
            min_norm_gini = 0.5
            min_norm_entropy = 0.75
            logging.warning(
                f"There's only one genome. Using min_norm_gini={min_norm_gini} and min_norm_entropy={min_norm_entropy}. Please check the results."
            )
            filter_conditions = {
                "min_read_length": args.min_read_length,
                "min_read_count": args.min_read_count,
                "min_expected_breadth_ratio": args.min_expected_breadth_ratio,
                "min_breadth": args.min_breadth,
                "min_avg_read_ani": args.min_avg_read_ani,
                "min_coverage_evenness": args.min_coverage_evenness,
                "min_coeff_var": args.min_coeff_var,
                "min_coverage_mean": args.min_coverage_mean,
                "min_norm_entropy": min_norm_entropy,
                "min_norm_gini": min_norm_gini,
            }
    else:
        min_norm_gini, min_norm_entropy = args.min_norm_gini, args.min_norm_entropy
        filter_conditions = {
            "min_read_length": args.min_read_length,
            "min_read_count": args.min_read_count,
            "min_expected_breadth_ratio": args.min_expected_breadth_ratio,
            "min_breadth": args.min_breadth,
            "min_avg_read_ani": args.min_avg_read_ani,
            "min_coverage_evenness": args.min_coverage_evenness,
            "min_coeff_var": args.min_coeff_var,
            "min_coverage_mean": args.min_coverage_mean,
            "min_norm_entropy": min_norm_entropy,
            "min_norm_gini": min_norm_gini,
        }

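    # The filtered stats/BAM pair is only produced when filtered stats were requested.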
    if args.stats_filtered is not None:
        filter_reference_BAM(
            bam=bam,
            df=data_df,
            filter_conditions=filter_conditions,
            transform_cov_evenness=args.transform_cov_evenness,
            threads=args.threads,
            out_files=out_files,
            sort_memory=args.sort_memory,
            sort_by_name=args.sort_by_name,
            min_read_ani=args.min_read_ani,
            disable_sort=args.disable_sort,
        )
    else:
        logging.info("Skipping filtering of reference BAM file.")
    if args.low_memory:
        os.remove(out_files["bam_tmp_sorted"])
    logging.info("ALL DONE.")
    if args.action == "reassign":
        reassign(args)
    elif args.action == "filter":
        filter_references(args)
    elif args.action == "lca":
        do_lca(args)


if __name__ == "__main__":
    main()
