From 74f232125b1b3c29b98e3e9d784608c6ebaf2771 Mon Sep 17 00:00:00 2001
From: genomewalker
Date: Thu, 21 Mar 2024 11:45:19 +0100
Subject: [PATCH 1/2] Fixed TAD abundance calculation for LCA

Normalise TAD weights by the truncated average depth length
(coverage_mean_trunc_len) instead of the full reference length, and carry
that column through the reference stats. Also disable datatable progress
bars in get_bam_data so worker output stays clean.
---
 bam_filter/lca.py      | 13 +++++++++----
 bam_filter/reassign.py |  2 +-
 2 files changed, 10 insertions(+), 5 deletions(-)

diff --git a/bam_filter/lca.py b/bam_filter/lca.py
index eea9114..52f0946 100644
--- a/bam_filter/lca.py
+++ b/bam_filter/lca.py
@@ -149,7 +149,7 @@ def create_tax_graph_w(tax_path, weight, lengths, ref_stats=None, scale=1_000_00
         if weight <= ref_stats[target][0] and weight > ref_stats[target][1]:
             res.iloc[-1, res.columns.get_loc("weight")] = ref_stats[target][1]
             res.iloc[-1, res.columns.get_loc("norm_weight")] = round(
-                (ref_stats[target][1] / lengths[target]) * scale
+                (ref_stats[target][1] / ref_stats[target][2]) * scale
             )
         else:
             res.iloc[-1, res.columns.get_loc("weight")] = weight
@@ -377,7 +377,9 @@ def do_lca(args):
     if args.lca_stats:
         log.info("Loading reference stats...")
         ref_stats = pd.read_csv(args.lca_stats, sep="\t", index_col=False)
-        ref_stats = ref_stats[["reference", "n_reads", "n_reads_tad"]]
+        ref_stats = ref_stats[
+            ["reference", "n_reads", "n_reads_tad", "coverage_mean_trunc_len"]
+        ]
         ref_stats = ref_stats[ref_stats["n_reads_tad"] > 0]
         if ref_stats.empty:
             del ref_stats
@@ -387,10 +389,13 @@
         ref_stats = dict(
             zip(
                 ref_stats["reference"],
-                zip(ref_stats["n_reads"], ref_stats["n_reads_tad"]),
+                zip(
+                    ref_stats["n_reads"],
+                    ref_stats["n_reads_tad"],
+                    ref_stats["coverage_mean_trunc_len"],
+                ),
             )
         )
-
     log.info("Getting reference to read mapping")
     ref_chunks = sort_keys_by_approx_weight(
         references_m, scale=1, num_cores=threads, refinement_steps=100
     )
diff --git a/bam_filter/reassign.py b/bam_filter/reassign.py
index 2b1228d..86717b9 100644
--- a/bam_filter/reassign.py
+++ b/bam_filter/reassign.py
@@ -461,7 +461,7 @@ def get_bam_data(
     K_value=0.21
 ):
     bam, references = parms
-    dt.options.progress.enabled = True
+    dt.options.progress.enabled = False
     dt.options.progress.clear_on_success = True
     if threads > 1:
         dt.options.nthreads = threads - 1

From c8876edb6d218e9059e69b549f1f43e7b917c6a2 Mon Sep 17 00:00:00 2001
From: genomewalker
Date: Tue, 26 Mar 2024 17:12:31 +0100
Subject: [PATCH 2/2] Fixed thread allocation when writing BAM files

Split the available threads among writer processes, pass the original BAM
header through when writing, and derive the number of reference chunks from
the total weight rather than the core count. Also precompute the
Karlin-Altschul constants and preallocate the alignment matrix.
---
 bam_filter/reassign.py  | 367 ++++++++++++++++++++++++++++------------
 bam_filter/sam_utils.py |   3 +-
 bam_filter/utils.py     |  10 +-
 3 files changed, 271 insertions(+), 109 deletions(-)

diff --git a/bam_filter/reassign.py b/bam_filter/reassign.py
index 86717b9..42a9e75 100644
--- a/bam_filter/reassign.py
+++ b/bam_filter/reassign.py
@@ -298,7 +298,9 @@ def write_reassigned_bam(
         threads=write_threads,
     )
 
-    num_cores = min(threads, cpu_count())
+    # Use one writer process per write_threads threads
+    num_cores = threads // write_threads
+    # num_cores = min(threads, cpu_count())
     # batch_size = len(references) // num_cores + 1  # Ensure non-zero batch size
     # batch_size = calc_chunksize(n_workers=num_cores, len_iterable=len(references))
     log.info("::: Creating reference chunks with uniform read amounts...")
@@ -307,7 +309,7 @@
         input_dict=ref_counts, scale=1, num_cores=num_cores
     )
     num_cores = min(num_cores, len(ref_chunks))
-    log.info(f"::: Using {num_cores} cores to write {len(ref_chunks)} chunk(s)")
+    log.info(f"::: Using {num_cores} processes to write {len(ref_chunks)} chunk(s)")
 
     with Manager() as manager:
         # Use Manager to create a read-only proxy for the dictionary
@@ -427,6 +429,110 @@
 
 # Values from:
 # https://www.ncbi.nlm.nih.gov/IEB/ToolBox/CPP_DOC/lxr/source/src/algo/blast/core/blast_stat.c
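+# A quick sanity check of the Karlin-Altschul conversion used in this module
+# (a sketch, assuming the default parameters lambda=1.02 and K=0.21 and a
+# hypothetical raw alignment score S = 95):
+#   bit_score = (1.02 * 95 - ln(0.21)) / ln(2)
+#             = (96.9 + 1.5606) / 0.6931
+#             ~= 142.1 bits
+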
+# def calculate_alignment_score(
+#     num_matches,
+#     num_mismatches,
+#     num_gaps,
+#     gap_extensions,
+#     match_reward,
+#     mismatch_penalty,
+#     gap_open_penalty,
+#     gap_extension_penalty,
+#     lambda_value,
+#     K_value
+# ):
+#     # Calculate the raw alignment score
+#     S = (num_matches * match_reward) - (num_mismatches * mismatch_penalty) - (num_gaps * gap_open_penalty) - (gap_extensions * gap_extension_penalty)
+
+#     # Calculate the approximate bit score
+#     bit_score = (lambda_value * S - math.log(K_value)) / math.log(2)
+
+#     return bit_score
+
+
+# def get_bam_data(
+#     parms,
+#     ref_lengths=None,
+#     percid=90,
+#     min_read_length=30,
+#     threads=1,
+#     match_reward=1,
+#     mismatch_penalty=-1,
+#     gap_open_penalty=1,
+#     gap_extension_penalty=2,
+#     lambda_value=1.02,
+#     K_value=0.21
+# ):
+#     bam, references = parms
+#     dt.options.progress.enabled = False
+#     dt.options.progress.clear_on_success = True
+#     if threads > 1:
+#         dt.options.nthreads = threads - 1
+#     else:
+#         dt.options.nthreads = 1
+#     if threads > 4:
+#         s_threads = 4
+#     else:
+#         s_threads = threads
+#     samfile = pysam.AlignmentFile(bam, "rb", threads=s_threads)
+
+#     results = []
+#     reads = set()
+#     refs = set()
+#     empty_df = 0
+
+#     for reference in references:
+#         if ref_lengths is None:
+#             reference_length = int(samfile.get_reference_length(reference))
+#         else:
+#             reference_length = int(ref_lengths[reference])
+#         aln_data = []
+#         for aln in samfile.fetch(
+#             contig=reference, multiple_iterators=False, until_eof=True
+#         ):
+#             # AS_value = aln.get_tag("AS") if aln.has_tag("AS") else None
+#             # XS_value = aln.get_tag("XS") if aln.has_tag("XS") else None
+#             query_length = (
+#                 aln.query_length if aln.query_length != 0 else aln.infer_query_length()
+#             )
+
+#             if query_length < min_read_length:
+#                 continue
+
+#             pident = (1 - ((aln.get_tag("NM") / query_length))) * 100
+#             num_mismatches = aln.get_tag("NM")
+#             num_matches = query_length - num_mismatches
+#             num_gaps = aln.get_tag("XO") if aln.has_tag("XO") else 0
+#             gap_extensions = aln.get_tag("XG") if aln.has_tag("XG") else 0
+#             if pident >= percid:
+#                 var = calculate_alignment_score(num_matches=num_matches, num_mismatches=num_mismatches, num_gaps=num_gaps, gap_extensions=gap_extensions, match_reward=match_reward, mismatch_penalty=mismatch_penalty, gap_open_penalty=gap_open_penalty, gap_extension_penalty=gap_extension_penalty, lambda_value=lambda_value, K_value=K_value)
+#                 aln_data.append(
+#                     (
+#                         aln.query_name,
+#                         aln.reference_name,
+#                         var,
+#                         reference_length,
+#                     )
+#                 )
+#                 reads.add(aln.query_name)
+#                 refs.add(aln.reference_name)
+#         # remove duplicates
+#         # check if aln_data is not empty
+#         if len(aln_data) > 0:
+#             aln_data_dt = dt.Frame(aln_data)
+#             # "queryId", "subjectId", "bitScore", "slen"
+#             aln_data_dt.names = ["queryId", "subjectId", "var", "slen"]
+#             # remove duplicates and keep the ones with the largest score
+#             aln_data_dt = aln_data_dt[
+#                 :1, :, dt.by(dt.f.queryId, dt.f.subjectId), dt.sort(-dt.f.var)
+#             ]
+#             results.append(aln_data_dt)
+#         else:
+#             results.append(dt.Frame())
+#             empty_df += 1
+#     samfile.close()
+#     return (dt.rbind(results), reads, refs, empty_df)
+
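+# Why two precomputed constants suffice below (a sketch; for fixed lambda and
+# K the bit score is an affine function of the raw score S):
+#   (lambda * S - ln(K)) / ln(2) = (lambda / ln(2)) * S - ln(K) / ln(2)
+#                                = precomputed_factor * S - precomputed_log_K
+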
 
 def calculate_alignment_score(
     num_matches,
     num_mismatches,
     num_gaps,
     gap_extensions,
     match_reward,
     mismatch_penalty,
     gap_open_penalty,
     gap_extension_penalty,
-    lambda_value,
-    K_value
+    precomputed_factor,  # This is lambda_value / math.log(2)
+    precomputed_log_K  # This is math.log(K_value) / math.log(2)
 ):
-    # Calculate the raw alignment score
+    # Calculate the raw alignment score with reduced arithmetic operations
     S = (num_matches * match_reward) - (num_mismatches * mismatch_penalty) - (num_gaps * gap_open_penalty) - (gap_extensions * gap_extension_penalty)
 
-    # Calculate the approximate bit score
-    bit_score = (lambda_value * S - math.log(K_value)) / math.log(2)
+    # Use the precomputed factors to calculate the approximate bit score
+    bit_score = precomputed_factor * S - precomputed_log_K
 
     return bit_score
 
 
 def get_bam_data(
     parms,
     ref_lengths=None,
     percid=90,
     min_read_length=30,
     threads=1,
     match_reward=1,
     mismatch_penalty=-1,
     gap_open_penalty=1,
     gap_extension_penalty=2,
     lambda_value=1.02,
     K_value=0.21
 ):
+    # Precompute factors for the score calculation to avoid redundant computation
+    precomputed_factor = lambda_value / math.log(2)
+    precomputed_log_K = math.log(K_value) / math.log(2)
+
     bam, references = parms
     dt.options.progress.enabled = False
     dt.options.progress.clear_on_success = True
-    if threads > 1:
-        dt.options.nthreads = threads - 1
-    else:
-        dt.options.nthreads = 1
-    if threads > 4:
-        s_threads = 4
-    else:
-        s_threads = threads
+    dt.options.nthreads = max(1, threads - 1)
+    s_threads = min(4, threads)
     samfile = pysam.AlignmentFile(bam, "rb", threads=s_threads)
 
     results = []
     reads = set()
     refs = set()
     empty_df = 0
 
     for reference in references:
-        if ref_lengths is None:
-            reference_length = int(samfile.get_reference_length(reference))
-        else:
-            reference_length = int(ref_lengths[reference])
+        reference_length = int(samfile.get_reference_length(reference)) if ref_lengths is None else int(ref_lengths[reference])
         aln_data = []
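+        # Identity filter used in the loop below (a sketch, assuming the
+        # default percid=90): a 100 bp read with NM=3 gives
+        # pident = (1 - 3/100) * 100 = 97.0 and is kept, while NM=12 gives
+        # 88.0 and is skipped.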
-        for aln in samfile.fetch(
-            contig=reference, multiple_iterators=False, until_eof=True
-        ):
-            # AS_value = aln.get_tag("AS") if aln.has_tag("AS") else None
-            # XS_value = aln.get_tag("XS") if aln.has_tag("XS") else None
-            query_length = (
-                aln.query_length if aln.query_length != 0 else aln.infer_query_length()
-            )
-
-            if query_length < min_read_length:
-                continue
-
-            pident = (1 - ((aln.get_tag("NM") / query_length))) * 100
-            num_mismatches = aln.get_tag("NM")
-            num_matches = query_length - num_mismatches
-            num_gaps = aln.get_tag("XO") if aln.has_tag("XO") else 0
-            gap_extensions = aln.get_tag("XG") if aln.has_tag("XG") else 0
-            if pident >= percid:
-                var = calculate_alignment_score(num_matches=num_matches, num_mismatches=num_mismatches, num_gaps=num_gaps, gap_extensions=gap_extensions, match_reward=match_reward, mismatch_penalty=mismatch_penalty, gap_open_penalty=gap_open_penalty, gap_extension_penalty=gap_extension_penalty, lambda_value=lambda_value, K_value=K_value)
-                aln_data.append(
-                    (
-                        aln.query_name,
-                        aln.reference_name,
-                        var,
-                        reference_length,
-                    )
-                )
-                reads.add(aln.query_name)
-                refs.add(aln.reference_name)
-        # remove duplicates
-        # check if aln_data is not empty
-        if len(aln_data) > 0:
-            aln_data_dt = dt.Frame(aln_data)
-            # "queryId", "subjectId", "bitScore", "slen"
-            aln_data_dt.names = ["queryId", "subjectId", "var", "slen"]
-            # remove duplicates and keep the ones with the largest mapq
+        for aln in samfile.fetch(contig=reference, multiple_iterators=False, until_eof=True):
+            query_length = aln.query_length if aln.query_length != 0 else aln.infer_query_length()
+
+            if query_length >= min_read_length:
+                num_mismatches = aln.get_tag("NM")
+                pident = (1 - (num_mismatches / query_length)) * 100
+                if pident >= percid:
+                    num_matches = query_length - num_mismatches
+                    num_gaps = aln.get_tag("XO") if aln.has_tag("XO") else 0
+                    gap_extensions = aln.get_tag("XG") if aln.has_tag("XG") else 0
+
+                    bit_score = calculate_alignment_score(num_matches, num_mismatches, num_gaps, gap_extensions, match_reward, mismatch_penalty, gap_open_penalty, gap_extension_penalty, precomputed_factor, precomputed_log_K)
+                    aln_data.append((aln.query_name, aln.reference_name, bit_score, reference_length))
+                    reads.add(aln.query_name)
+                    refs.add(aln.reference_name)
+
+        # remove duplicates and keep the alignment with the largest bit score
+        if aln_data:
+            aln_data_dt = dt.Frame(aln_data, names=["queryId", "subjectId", "bitScore", "slen"])
             aln_data_dt = aln_data_dt[
-                :1, :, dt.by(dt.f.queryId, dt.f.subjectId), dt.sort(-dt.f.var)
+                :1, :, dt.by(dt.f.queryId, dt.f.subjectId), dt.sort(-dt.f.bitScore)
             ]
-            # aln_data_dt["slen"] = dt.Frame([ref_lengths[x] for x in aln_data_dt[:,"subjectId"].to_list()[0]])
             results.append(aln_data_dt)
         else:
-            results.append(dt.Frame())
             empty_df += 1
-    samfile.close()
-    return (dt.rbind(results), reads, refs, empty_df)
+
+    samfile.close()
+    combined_results = dt.rbind(results)
+    return (combined_results, reads, refs, empty_df)
 
 def reassign_reads(
     bam,
     log.info("::: Creating reference chunks with uniform read amounts...")
     # Specify the number of chunks
     ref_chunks = sort_keys_by_approx_weight(
-        input_dict=references_m, scale=1, num_cores=threads
+        input_dict=references_m, scale=1, num_cores=threads, verbose=False
     )
     log.info(f"::: Created {len(ref_chunks):,} chunks")
     ref_chunks = random.sample(ref_chunks, len(ref_chunks))
 
     dt.options.progress.enabled = False
     n_alns_0 = 0
     # Loop through each DataFrame in the list and update it with the joined version
     log.info("::: Combining data...")
+    # for i, x in tqdm.tqdm(
+    #     enumerate(data),
+    #     total=len(data),
+    #     desc="Processing batches",
+    #     unit="batch",
+    #     disable=is_debug(),
+    #     leave=False,
+    #     ncols=80,
+    # ):
+    #     # Perform join with reads and then refs
+    #     x = x[:, :, dt.join(reads)]
+    #     x = x[:, :, dt.join(refs)]
+    #     n_alns_0 += x.shape[0]
+    #     del x["queryId"]
+    #     del x["subjectId"]
+    #     x = x[:, [dt.f.qidx, dt.f.sidx, dt.f.var, dt.f.slen]].to_numpy()
+    #     # Substitute the original DataFrame with the joined version in the list
+    #     data[i] = x
+
+    # Calculate the total number of rows in advance (data is a list of
+    # datatable Frames at this point)
+    total_rows = sum(x.shape[0] for x in data)
+
+    # Every processed chunk keeps the same four columns: qidx, sidx,
+    # bitScore and slen
+    num_columns = 4
+
+    # Preallocate the NumPy array
+    mat = np.empty((total_rows, num_columns), dtype=np.float32)
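+    # Preallocating once avoids the repeated copies np.vstack would make; as
+    # a rough sketch with hypothetical numbers, 50M alignments x 4 float32
+    # columns is 50e6 * 4 * 4 bytes, i.e. ~800 MB allocated a single time.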
+
+    current_index = 0
     for i, x in tqdm.tqdm(
         enumerate(data),
         total=len(data),
         desc="Processing batches",
         unit="batch",
         disable=is_debug(),
         leave=False,
         ncols=80,
     ):
         # Perform join with reads and then refs
-        x = x[:, :, dt.join(reads)]
-        x = x[:, :, dt.join(refs)]
-        n_alns_0 += x.shape[0]
-        del x["queryId"]
-        del x["subjectId"]
-        x = x[:, [dt.f.qidx, dt.f.sidx, dt.f.var, dt.f.slen]].to_numpy()
-        # Substitute the original DataFrame with the joined version in the list
-        data[i] = x
-
-    # After the loop, use dt.rbind() to combine all the DataFrames in the list
-    # data = dt.rbind([x for x in data])
-
-    # del data["queryId"]
-    # del data["subjectId"]
-    # n_alns_0 = data.shape[0]
+        if x.shape[0] > 0:
+            x = x[:, :, dt.join(reads)]
+            x = x[:, :, dt.join(refs)]
+            n_alns_0 += x.shape[0]
+
+            # Process `x` as before, but write the rows directly into `mat`
+            x_processed = x[:, [dt.f.qidx, dt.f.sidx, dt.f.bitScore, dt.f.slen]].to_numpy()
+            num_rows = x_processed.shape[0]
+
+            # Fill the preallocated array
+            mat[current_index:current_index + num_rows, :] = x_processed
+
+            # Update the current index
+            current_index += num_rows
+
+    # After the loop `mat` already holds the concatenated data, so no further
+    # concatenation or conversion is needed
+    data = None  # Free the memory held by the per-chunk frames
+
+    # Log the final stats
     n_reads_0 = reads.shape[0]
     n_refs_0 = refs.shape[0]
     log.info(
         f"::: References: {n_refs_0:,} | Reads: {n_reads_0:,} | Alignments: {n_alns_0:,}"
     )
-    log.info("::: Allocating data structures...")
-    # mat = data[:, [dt.f.qidx, dt.f.sidx, dt.f.var, dt.f.slen]].to_numpy()
-    mat = np.vstack(data)
-    data = None
-
-    # Create a zeros array with the same number of rows as 'mat'
-    zeros_array = np.zeros((mat.shape[0], 5))
-
-    # Stack the zeros_array with the original 'mat'
-    m = np.column_stack([mat, zeros_array])
-    zeros_array = None
-
-    dtype = np.dtype(
-        [
-            ("source", "int"),
-            ("subject", "int"),
-            ("var", "float"),
-            ("slen", "int"),
-            ("s_W", "float"),
-            ("prob", "float"),
-            ("iter", "int"),
-            ("n_aln", "int"),
-            ("max_prob", "float"),
-        ]
-    )
-    # Convert the unstructured array to structured array
-    m = rf.unstructured_to_structured(m, dtype)
+
+    # After the loop, use dt.rbind() to combine all the DataFrames in the list
+    # data = dt.rbind([x for x in data])
+
+    # del data["queryId"]
+    # del data["subjectId"]
+    # n_alns_0 = data.shape[0]
+    # n_reads_0 = reads.shape[0]
+    # n_refs_0 = refs.shape[0]
+    # log.info(
+    #     f"::: References: {n_refs_0:,} | Reads: {n_reads_0:,} | Alignments: {n_alns_0:,}"
+    # )
+
+    # log.info("::: Allocating data...")
+    # # mat = data[:, [dt.f.qidx, dt.f.sidx, dt.f.var, dt.f.slen]].to_numpy()
+    # # mat = np.vstack(data)
+    # # data = None
+
+    # # Create a zeros array with the same number of rows as 'mat'
+    # zeros_array = np.zeros((mat.shape[0], 5))
+
+    # # Stack the zeros_array with the original 'mat'
+    # m = np.column_stack([mat, zeros_array])
+    # zeros_array = None
+
+    # dtype = np.dtype(
+    #     [
+    #         ("source", "int"),
+    #         ("subject", "int"),
+    #         ("var", "float"),
+    #         ("slen", "int"),
+    #         ("s_W", "float"),
+    #         ("prob", "float"),
+    #         ("iter", "int"),
+    #         ("n_aln", "int"),
+    #         ("max_prob", "float"),
+    #     ]
+    # )
+
+    # # Convert the unstructured array to structured array
+    # m = rf.unstructured_to_structured(m, dtype)
+    # gc.collect()
+
+    log.info("::: Allocating data...")
+
+    # Define the dtype for the structured array
+    dtype = np.dtype([
+        ("source", "int"),
+        ("subject", "int"),
+        ("var", "float"),
+        ("slen", "int"),
+        ("s_W", "float"),  # This and the following fields start at 0 or a default value
+        ("prob", "float"),
+        ("iter", "int"),
+        ("n_aln", "int"),
+        ("max_prob", "float"),
+    ])
+
+    # Initialize the structured array with zeros directly
+    m = np.zeros(mat.shape[0], dtype=dtype)
+    m['source'] = mat[:, 0]
+    m['subject'] = mat[:, 1]
+    m['var'] = mat[:, 2]
+    m['slen'] = mat[:, 3]
+
+    # Free intermediate arrays that are no longer needed
     gc.collect()
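+
+    # Layout sketch of the structured array above, for one hypothetical
+    # alignment row: m[0] could read (source=0, subject=3, var=142.1,
+    # slen=1502, s_W=0.0, prob=0.0, iter=0, n_aln=0, max_prob=0.0); the last
+    # five fields start at zero and are updated during the reassignment
+    # iterations.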
"rb", threads=s_threads) - + header = samfile.header if threads > 4: threads = 4 @@ -91,6 +91,7 @@ def write_bam(bam, references, output_files, threads=1, sort_memory="1G"): referencenames=list(ref_names), referencelengths=ref_lengths, threads=write_threads, + header=header, ) references = [x for x in samfile.references if x in refs_idx.keys()] diff --git a/bam_filter/utils.py b/bam_filter/utils.py index 05eabb6..336b3ea 100644 --- a/bam_filter/utils.py +++ b/bam_filter/utils.py @@ -84,17 +84,19 @@ def refine_chunks(chunks, input_dict, target_weight): def sort_keys_by_approx_weight( - input_dict, scale=1, num_cores=1, refinement_steps=10, verbose=False + input_dict, scale=1, num_cores=1, refinement_steps=10,verbose=False ): if scale == 0: raise ValueError("Scale cannot be zero.") # Calculate the target weight for each chunk - target_weight = int(scale * max(input_dict.values())) + target_weight = scale * int(max(input_dict.values())) # Determine the initial number of chunks based on the number of cores - num_chunks = num_cores - + #num_chunks = num_cores * scale + num_chunks = (((sum(input_dict.values()) // target_weight)) // num_cores) + 1 + if num_chunks < num_cores: + num_chunks = num_cores # Sort keys by their weights in descending order sorted_keys = sorted(input_dict, key=lambda k: input_dict[k], reverse=True)