diff --git a/README.md b/README.md
index 6d10180..8108bfd 100644
--- a/README.md
+++ b/README.md
@@ -348,47 +348,74 @@ filterBAM lca --bam c55d4e2df1.dedup.filtered.bam --names ./taxonomy/names.dmp
 - **--scale**: Scale taxonomic abundance by this factor; suffix K/M recognized
-## How the reassignment process works
+## Read Reassignment Algorithm
-The read reassignment algorithm aims to resolve multi-mapping reads by iteratively refining alignment probabilities between reads and reference sequences. It begins by calculating an initial score $S$ for each alignment:
+The algorithm implements a SQUAREM-accelerated Expectation-Maximization (EM) approach that resolves multi-mapping reads by iteratively refining alignment probabilities. For each alignment, we first compute a global alignment score $S$:
 $$S = r_m M - p_m X - p_o G - p_e E$$
-where $M$, $X$, $G$, and $E$ represent the number of matches, mismatches, gap openings, and gap extensions respectively. The terms $r_m$, $p_m$, $p_o$, and $p_e$ are the corresponding rewards or penalties.
+where:
+- $M$ is the number of matches
+- $X$ is the number of mismatches
+- $G$ is the number of gap openings
+- $E$ is the number of gap extensions
+- $r_m$, $p_m$, $p_o$, and $p_e$ are the corresponding match reward and mismatch, gap-open, and gap-extension penalties
-These raw scores are then normalized in two steps. First, they are shifted to ensure they are positive:
+The scores are normalized in two steps. First, we apply a positive shift:
 $$S' = S - \min(S) + 1$$
-Then, they are divided by the alignment length $L$:
+followed by length normalization:
 $$S'' = \frac{S'}{L}$$
-Using these normalized scores, we initialize the probability $P(r_i|g_j)$ of read $r_i$ originating from reference sequence $g_j$:
+where $L$ is the alignment length.
-$$P(r_i|g_j) = \frac{S''\_{ij}}{\sum_{k} S''\_{ik}}$$
+These normalized scores initialize the probability $P(r_i|g_j)$ that read $r_i$ originates from genome $g_j$:
-The algorithm then enters an iterative refinement phase. In each iteration, it first calculates subject weights:
+$$P(r_i|g_j) = \frac{S''_{ij}}{\sum_{k} S''_{ik}}$$
-$$W_j = \sum_i P(r_i|g_j)$$
+Each SQUAREM-accelerated iteration then proceeds as follows:
-$$W'_j = \frac{W_j}{L_j}$$
+1. **E-step**: Calculate subject weights
+   $$W_j = \sum_i P(r_i|g_j)$$
+
+   and normalize them by sequence length:
+   $$W'_j = \frac{W_j}{L_j}$$
+   where $L_j$ is the length of genome $j$.
-where $L_j$ is the length of reference sequence $g_j$. These weights are used to update the probabilities:
+2. **M-step**: Update the probabilities
+   $$P'(r_i|g_j) = P(r_i|g_j) \cdot W'_j$$
+   $$P''(r_i|g_j) = \frac{P'(r_i|g_j)}{\sum_k P'(r_i|g_k)}$$
-$$P'(r_i|g_j) = P(r_i|g_j) \cdot W'_j$$
+3. **SQUAREM acceleration**: Let $q_1 = \mathrm{EM}(P)$ and $q_2 = \mathrm{EM}(q_1)$ be two consecutive EM updates of the current probabilities $P$. Define:
+   $$r = q_1 - P$$
+   $$v = (q_2 - q_1) - r$$
+
+   The step length is:
+   $$\gamma = -\sqrt{\frac{\|r\|^2}{\|v\|^2}}$$
+
+   and the accelerated update is:
+   $$P_{\text{new}} = P - 2\gamma r + \gamma^2 v$$
-$$P''(r_i|g_j) = \frac{P'(r_i|g_j)}{\sum_k P'(r_i|g_k)}$$
+4. **Assignment step**: For each read $r_i$, calculate:
+   $$P_{\text{max}}(r_i) = \max_j P''(r_i|g_j)$$
+
+   and retain only the alignments satisfying:
+   $$P''(r_i|g_j) \geq \alpha \cdot P_{\text{max}}(r_i)$$
+   where $\alpha$ is a scaling factor (default 0.9).
-Following this update, the algorithm filters alignments. 
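The snippet below is a minimal NumPy sketch of the scoring, normalization, one EM iteration, and the filtering rule described above. It is an illustration only, not the filterBAM implementation: the alignment counts, penalties, and reference lengths are invented, and the real code works on BAM-derived arrays with SQUAREM acceleration and memory-mapped storage.

```python
import numpy as np

# Toy data: 5 alignments of 3 reads against 2 references (all values invented).
read    = np.array([0, 0, 1, 1, 2])        # read index per alignment
ref     = np.array([0, 1, 0, 1, 1])        # reference index per alignment
M, X    = np.array([95, 90, 80, 85, 99]), np.array([5, 10, 20, 15, 1])
G, E    = np.array([1, 0, 2, 1, 0]), np.array([2, 0, 3, 1, 0])
aln_len = np.array([100, 100, 100, 100, 100])
ref_len = np.array([3000, 5000])
r_m, p_m, p_o, p_e = 1, 1, 1, 2            # illustrative reward/penalties

# Global alignment score, positive shift, and length normalization
S = r_m * M - p_m * X - p_o * G - p_e * E
S_norm = (S - S.min() + 1) / aln_len

# Initial P(r_i | g_j): normalize scores per read
P = S_norm / np.bincount(read, weights=S_norm)[read]

# One EM iteration: subject weights, length normalization, probability update
W = np.bincount(ref, weights=P, minlength=ref_len.size)
P = P * (W[ref] / ref_len[ref])
P = P / np.bincount(read, weights=P)[read]

# Assignment step: keep alignments within `scale` of each read's best probability
scale = 0.9
P_max = np.zeros(read.max() + 1)
np.maximum.at(P_max, read, P)
keep = P >= scale * P_max[read]
print(np.round(P, 3), keep)
```

In the real implementation these steps run over millions of alignments, the EM update is wrapped in the SQUAREM extrapolation of step 3, and reads that have become unique are kept aside before the next iteration.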
For each read $r_i$, it calculates the maximum probability across all references: +The algorithm iterates until convergence, defined by one of these criteria: +- Log-likelihood improvement $< \epsilon$ +- Maximum iterations reached +- Complete resolution of multi-mapping reads +- No further alignments can be removed -$$P_{max}(r_i) = \max_j P''(r_i|g_j)$$ +This implementation builds on the SQUAREM method for EM acceleration (Varadhan & Roland, 2008) and incorporates elements from probabilistic read assignment algorithms in metagenomic analyses. The approach ensures robust convergence while maintaining the statistical properties of maximum likelihood estimation. -It then applies a scaling factor $\alpha$ (a value < 1) and retains only alignments satisfying: - -$$P''(r_i|g_j) \geq \alpha \cdot P_{max}(r_i)$$ +### Applications and recommendations -This iterative process continues until no more alignments are removed or a maximum number of iterations is reached. The final read-to-reference assignments are implicit in the filtering process, with only the highest probability alignments retained. +[Rest of the original content remains the same...] ### Applications and recommendations diff --git a/bam_filter/reassign.py b/bam_filter/reassign.py index d29cc33..cd158c0 100644 --- a/bam_filter/reassign.py +++ b/bam_filter/reassign.py @@ -28,9 +28,14 @@ import uuid import psutil from memory_profiler import profile +from numba import jit, prange +import numba log = logging.getLogger("my_logger") +# hide warnings from numba +logging.getLogger("numba").setLevel(logging.ERROR) + def estimate_array_size(dtype, shape): """ @@ -71,177 +76,290 @@ def initialize_subject_weights(data, mmap_dir=None, max_memory=None): return None -def resolve_multimaps(data, scale=0.9, iters=10, mmap_dir=None, max_memory=None): - total_memory = max_memory if max_memory else psutil.virtual_memory().total - current_iter = 0 - - while True: - progress_bar = tqdm.tqdm( - total=9, - desc=f"Iter {current_iter + 1}", - unit=" step", - disable=False, - leave=False, - ncols=80, - ) - log.debug(f"::: Iter: {current_iter + 1} - Getting scores") - # step 1 - progress_bar.update(1) - n_alns = data.shape[0] - log.debug(f"::: Iter: {current_iter + 1} - Total alignment: {n_alns:,}") - - log.debug(f"::: Iter: {current_iter + 1} - Calculating weights...") - # step 2 - progress_bar.update(1) - max_subject = np.max(data["subject"]) - subject_weights_size = estimate_array_size( - np.float64, (np.int64(max_subject) + 1,) - ) +def configure_numba_threads(threads=None): + """ + Configure Numba threading behavior. - if subject_weights_size > total_memory * 0.8: - # Use memory-mapped arrays - subject_weights = np.memmap( - os.path.join(mmap_dir, f"subject_weights_{current_iter}.mmap"), - dtype="float64", - mode="w+", - shape=(np.int64(max_subject) + 1,), - ) - else: - # Use in-memory arrays - subject_weights = np.zeros(np.int64(max_subject) + 1, dtype="float64") + Args: + threads: Number of threads to use. If None, uses CPU count. 
+ Returns: + Original number of threads (for restoration if needed) + """ + import numba + from numba import config + + # Store original settings + original_threads = numba.get_num_threads() + + if threads is None: + threads = min(os.cpu_count(), 8) # Limit to reasonable number + + # Set thread count for Numba + numba.set_num_threads(threads) + + # Enable work stealing for better load balancing + config.WORKQUEUE_THREAD_ALLOCATION = "steal" + + return original_threads + + +@jit(nopython=True, parallel=True) +def fast_e_step(prob_vector, subject_indices, slens, n_chunks): + """E-step using chunked accumulation for deterministic results""" + # Calculate chunk size + array_len = len(prob_vector) + chunk_size = (array_len + n_chunks - 1) // n_chunks # Ceiling division + n_chunks_actual = (array_len + chunk_size - 1) // chunk_size + + # Create per-chunk accumulators to ensure deterministic summation + max_subject_idx = np.max(subject_indices) + 1 + chunk_accumulators = np.zeros((n_chunks_actual, max_subject_idx), dtype=np.float64) + + # First pass: Accumulate within chunks + for chunk_idx in prange(n_chunks_actual): + chunk_start = chunk_idx * chunk_size + chunk_end = min(chunk_start + chunk_size, array_len) + + # Process chunk with local accumulator + for i in range(chunk_start, chunk_end): + subject_idx = subject_indices[i] + chunk_accumulators[chunk_idx, subject_idx] += prob_vector[i] + + # Merge chunks deterministically + subject_weights = np.zeros(max_subject_idx, dtype=np.float64) + for subject_idx in range(max_subject_idx): + for chunk_idx in range(n_chunks_actual): + subject_weights[subject_idx] += chunk_accumulators[chunk_idx, subject_idx] + + # Calculate s_W deterministically + s_w = np.empty_like(prob_vector, dtype=np.float64) + for i in range(array_len): + s_w[i] = subject_weights[subject_indices[i]] / slens[i] + + return s_w, subject_weights + + +@jit(nopython=True, parallel=True) +def fast_m_step(prob_vector, s_w, source_indices, n_chunks): + """M-step using chunked accumulation for deterministic results""" + # Calculate chunk size + array_len = len(prob_vector) + chunk_size = (array_len + n_chunks - 1) // n_chunks + n_chunks_actual = (array_len + chunk_size - 1) // chunk_size + + # Initialize arrays + new_prob = prob_vector * s_w + max_source_idx = np.max(source_indices) + 1 + + # Create per-chunk accumulators + chunk_accumulators = np.zeros((n_chunks_actual, max_source_idx), dtype=np.float64) + + # First pass: Accumulate within chunks + for chunk_idx in prange(n_chunks_actual): + chunk_start = chunk_idx * chunk_size + chunk_end = min(chunk_start + chunk_size, array_len) + + # Process chunk + for i in range(chunk_start, chunk_end): + source_idx = source_indices[i] + chunk_accumulators[chunk_idx, source_idx] += new_prob[i] + + # Merge chunks deterministically + source_sums = np.zeros(max_source_idx, dtype=np.float64) + for source_idx in range(max_source_idx): + for chunk_idx in range(n_chunks_actual): + source_sums[source_idx] += chunk_accumulators[chunk_idx, source_idx] + + # Normalize probabilities deterministically + for i in range(array_len): + source_idx = source_indices[i] + if source_sums[source_idx] > 0: + new_prob[i] /= source_sums[source_idx] + + return new_prob + + +def parallel_em_step( + prob_vector, + subject_indices, + source_indices, + slens, + n_chunks, +): + """Parallel EM step with deterministic operations""" + # Execute E and M steps + s_w, subject_weights = fast_e_step(prob_vector, subject_indices, slens, n_chunks) - np.add.at(subject_weights, data["subject"], 
data["prob"]) - data["s_W"] = subject_weights[data["subject"]] / data["slen"] - del subject_weights + new_prob = fast_m_step(prob_vector, s_w, source_indices, n_chunks) - log.debug(f"::: Iter: {current_iter + 1} - Calculating probabilities") - # step 3 - progress_bar.update(1) - new_prob = data["prob"] * data["s_W"] - max_source = np.max(data["source"]) - prob_sum_array_size = estimate_array_size( - np.float64, (np.int64(max_source) + 1,) - ) + return new_prob - if prob_sum_array_size > total_memory * 0.8: - # Use memory-mapped arrays - prob_sum_array = np.memmap( - os.path.join(mmap_dir, f"prob_sum_array_{current_iter}.mmap"), - dtype="float64", - mode="w+", - shape=(np.int64(max_source) + 1,), + +def squarem_resolve_multimaps( + data, + scale=0.9, + iters=10, + mmap_dir=None, + max_memory=None, + min_improvement=1e-4, + threads=None, +): + """SQUAREM-accelerated multimapping resolution.""" + # Configure threading + original_threads = configure_numba_threads(threads) + n_threads = numba.get_num_threads() + + try: + total_memory = max_memory if max_memory else psutil.virtual_memory().total + current_iter = 0 + + # Pre-calculate constants and arrays + max_subject = np.max(data["subject"]) + max_source = np.max(data["source"]) + source_indices = data["source"].astype(np.int64) + subject_indices = data["subject"].astype(np.int64) + slens = data["slen"].astype(np.float64) + + # Calculate optimal number of chunks + n_chunks = min(32, len(data["prob"]) // 1000) + if n_chunks < 1: + n_chunks = 1 + + def adaptive_squarem_step(): + """Adaptive SQUAREM step""" + # First EM update + r = parallel_em_step( + data["prob"].astype(np.float64), + subject_indices, + source_indices, + slens, + n_chunks, ) - else: - # Use in-memory arrays - prob_sum_array = np.zeros(np.int64(max_source) + 1, dtype="float64") - - np.add.at(prob_sum_array, data["source"], new_prob) - data["prob"] = new_prob / prob_sum_array[data["source"]] - del new_prob, prob_sum_array - - log.debug("Calculating query counts") - # step 4 - progress_bar.update(1) - query_counts = np.bincount(data["source"]) - - log.debug("Calculating query counts array") - # step 5 - progress_bar.update(1) - query_counts_array_size = estimate_array_size( - np.int64, (np.int64(max_source) + 1,) - ) - if query_counts_array_size > total_memory * 0.8: - # Use memory-mapped arrays - query_counts_array = np.memmap( - os.path.join(mmap_dir, f"query_counts_array_{current_iter}.mmap"), - dtype="int64", - mode="w+", - shape=(np.int64(max_source) + 1,), + # Calculate first difference + v = ( + parallel_em_step( + r, + subject_indices, + source_indices, + slens, + n_chunks, + ) + - r ) - else: - # Use in-memory arrays - query_counts_array = np.zeros(np.int64(max_source) + 1, dtype="int64") - np.add.at(query_counts_array, data["source"], 1) + # Calculate second difference + w = parallel_em_step( + r + v, + subject_indices, + source_indices, + slens, + n_chunks, + ) - (r + v) - log.debug( - f"::: Iter: {current_iter + 1} - Calculating number of alignments per query" - ) - # step 6 - progress_bar.update(1) - data["n_aln"] = query_counts_array[data["source"]] + # Compute step size using stable computation + v_norm = np.sqrt(np.sum(v * v)) # Explicit sum for stability + w_v_norm = np.sqrt(np.sum((w - v) * (w - v))) - unique_mask = data["n_aln"] == 1 - non_unique_mask = data["n_aln"] > 1 + if w_v_norm == 0: + return r - if np.all(unique_mask): - # step 7 - progress_bar.close() - log.info("::: ::: No more multimapping reads. 
Early stopping.") - return data + alpha = v_norm / w_v_norm + alpha = min(max(alpha, 0.1), 4.0) # Clip step size for stability - log.debug("Calculating max_prob") - max_prob_size = estimate_array_size(np.float64, (np.int64(max_source) + 1,)) + # SQUAREM update with numerical stability checks + new_prob = data["prob"] + 2 * alpha * v + (alpha * alpha) * (w - v) - if max_prob_size > total_memory * 0.8: - # Use memory-mapped arrays - max_prob = np.memmap( - os.path.join(mmap_dir, f"max_prob_{current_iter}.mmap"), - dtype="float64", - mode="w+", - shape=(np.int64(max_source) + 1,), + # Validate results + if np.any(new_prob < 0) or np.any(np.isnan(new_prob)): + return r + + return new_prob + + # Main iteration loop with convergence checking + prev_likelihood = -np.inf + + while True: + progress_bar = tqdm.tqdm( + total=9, + desc=f"Iter {current_iter + 1}", + unit=" step", + disable=False, + leave=False, + ncols=80, ) - else: - # Use in-memory arrays - max_prob = np.zeros(np.int64(max_source) + 1, dtype="float64") - np.maximum.at(max_prob, data["source"], data["prob"]) - data["max_prob"] = max_prob[data["source"]] * scale - del max_prob + # SQUAREM update + new_prob = adaptive_squarem_step() + data["prob"] = new_prob + progress_bar.update(4) + + # Calculate convergence using stable log sum + current_likelihood = np.sum(np.log(new_prob[new_prob > 0])) + improvement = ( + (current_likelihood - prev_likelihood) / abs(prev_likelihood) + if prev_likelihood != -np.inf + else float("inf") + ) + prev_likelihood = current_likelihood - log.debug( - f"::: Iter: {current_iter + 1} - Removing alignments with lower probability" - ) - # step 8 - progress_bar.update(1) - to_remove = np.sum(data["prob"] < data["max_prob"]) + # Query counts + query_counts = np.bincount(source_indices) + data["n_aln"] = query_counts[source_indices] - filter_mask = data["prob"] >= data["max_prob"] - final_mask = non_unique_mask & filter_mask + unique_mask = data["n_aln"] == 1 + non_unique_mask = data["n_aln"] > 1 - current_iter += 1 - data["iter"][final_mask] = current_iter + if np.all(unique_mask): + progress_bar.close() + log.info("::: ::: No more multimapping reads. 
Early stopping.") + return data - # Concatenate unique and filtered non-unique data - data = np.concatenate([data[unique_mask], data[final_mask]]) + # Calculate maximum probabilities + max_prob = np.zeros(max_source + 1, dtype=np.float64) + np.maximum.at(max_prob, source_indices, data["prob"]) + data["max_prob"] = max_prob[source_indices] * scale - query_counts = np.bincount(data["source"]) - total_n_unique = np.sum(query_counts <= 1) + # Filter alignments + to_remove = np.sum(data["prob"] < data["max_prob"]) + filter_mask = data["prob"] >= data["max_prob"] + final_mask = non_unique_mask & filter_mask - keep_processing = to_remove != 0 - log.debug(f"::: Iter: {current_iter} - Removed {to_remove:,} alignments") - log.debug( - f"::: Iter: {current_iter} - Total mapping queries: {np.sum(unique_mask):,}" - ) - log.debug( - f"::: Iter: {current_iter} - New unique mapping queries: {total_n_unique:,}" - ) - log.debug(f"::: Iter: {current_iter} - Alns left: {data.shape[0]:,}") - # step 9 - progress_bar.update(1) - progress_bar.close() - log.info( - f"::: Iter: {current_iter} - R: {to_remove:,} | U: {np.sum(unique_mask):,} | NU: {total_n_unique:,} | L: {data.shape[0]:,}" - ) - log.debug(f"::: Iter: {current_iter} - done!") + current_iter += 1 + data["iter"][final_mask] = current_iter + + # Update data array with filtered results + data = np.concatenate([data[unique_mask], data[final_mask]]) + + # Update indices for next iteration + source_indices = data["source"].astype(np.int64) + subject_indices = data["subject"].astype(np.int64) + slens = data["slen"].astype(np.float64) + + progress_bar.update(5) + progress_bar.close() + + # Logging + log.info( + f"::: Iter: {current_iter} - R: {to_remove:,} | U: {np.sum(unique_mask):,} | " + f"NU: {len(np.unique(data[data['n_aln'] > 1]['source'])):,} | L: {data.shape[0]:,} | " + f"Improvement: {improvement:.6f}" + ) - if iters > 0 and current_iter >= iters: - log.info("::: ::: Reached maximum iterations. Stopping.") - break - elif to_remove == 0: - log.info("::: ::: No more alignments to remove. Stopping.") - break + # Check stopping conditions + if improvement < min_improvement: + log.info("::: ::: Converged. Stopping.") + break + elif iters > 0 and current_iter >= iters: + log.info("::: ::: Reached maximum iterations. Stopping.") + break + elif to_remove == 0: + log.info("::: ::: No more alignments to remove. 
Stopping.") + break + + finally: + # Restore original Numba thread settings + numba.set_num_threads(original_threads) return data @@ -478,8 +596,7 @@ def normalize_scores(scores): return probabilities -@profile -def get_bam_data( +def process_alignments( parms, ref_lengths=None, percid=90, @@ -492,7 +609,7 @@ def get_bam_data( gap_extension_penalty=2, tmpdir=None, ): - + """Process BAM alignments and calculate alignment scores.""" bam, references = parms dt.options.progress.enabled = False dt.options.progress.clear_on_success = True @@ -513,88 +630,60 @@ def get_bam_data( for reference in references: reference_length = reference_lengths[reference] - aln_data = [] - fetch = samfile.fetch(reference, multiple_iterators=False, until_eof=True) - # Initialize a list to store alignment information along with raw scores alignment_info = [] - # Step 1: Calculate the Global Alignment Scores and store them along with other alignment information - for aln in fetch: + for aln in samfile.fetch( + reference, multiple_iterators=False, until_eof=True + ): query_length = aln.query_length or aln.infer_query_length() - if query_length >= min_read_length and query_length <= max_read_length: + if min_read_length <= query_length <= max_read_length: num_mismatches = aln.get_tag("NM") pident = (1 - (num_mismatches / query_length)) * 100 if pident >= percid: num_matches = query_length - num_mismatches num_gaps = aln.get_tag("XO") if aln.has_tag("XO") else 0 gap_extensions = aln.get_tag("XG") if aln.has_tag("XG") else 0 - # Calculate the global alignment score S - S = ( + + score = ( (num_matches * match_reward) - (num_mismatches * mismatch_penalty) - (num_gaps * gap_open_penalty) - (gap_extensions * gap_extension_penalty) ) - # Store the alignment information including the raw score + alignment_info.append( - ( - aln.query_name, - aln.reference_name, - reference_length, - S, # Store the raw alignment score here - aln.query_alignment_length, - ) + (aln.query_name, aln.reference_name, score, query_length) ) - # Step 2: Extract scores for normalization if alignment_info: - # Extract raw scores from alignment_info - raw_scores = np.array([info[3] for info in alignment_info]) - aln_lengths = np.array([info[4] for info in alignment_info]) - # Shift Scores to Ensure All Positive Values - min_score = np.min(raw_scores) - shifted_scores = ( - raw_scores - min_score + 1 - ) # Shift scores to make all positive - shifted_scores = shifted_scores / aln_lengths - # Step 3: Update alignment_info with the normalized probabilities - aln_data = [] - for i, info in enumerate(alignment_info): - aln_data.append( - ( - info[0], # query_name - info[1], # reference_name - shifted_scores[i], # normalized probability - info[2], # reference_length - ) - ) - else: - # Handle the case where no alignments passed the filter - aln_data = [] - - if aln_data: - aln_data_dt = dt.Frame( - aln_data, names=["queryId", "subjectId", "bitScore", "slen"] + scores = np.array([info[2] for info in alignment_info]) + lengths = np.array([info[3] for info in alignment_info]) + normalized_scores = (scores - np.min(scores) + 1) / lengths + + aln_data = dt.Frame( + [ + (info[0], info[1], normalized_scores[i], reference_length) + for i, info in enumerate(alignment_info) + ], + names=["queryId", "subjectId", "bitScore", "slen"], ) - aln_data_dt = aln_data_dt[ + + aln_data = aln_data[ :1, :, dt.by(dt.f.queryId, dt.f.subjectId), dt.sort(-dt.f.bitScore) ] - results.append(aln_data_dt) + results.append(aln_data) else: empty_df += 1 if results: combined_results = 
dt.rbind(results) if tmpdir is not None: - uuid_name = str(uuid.uuid4()) - jay_file = os.path.join(tmpdir, f"{uuid_name}.jay") + jay_file = os.path.join(tmpdir, f"{uuid.uuid4()}.jay") combined_results.to_jay(jay_file) del combined_results return (jay_file, empty_df) - else: - return (combined_results, empty_df) - else: - return (None, empty_df) + return (combined_results, empty_df) + return (None, empty_df) def reassign_reads( @@ -709,7 +798,7 @@ def reassign_reads( tqdm.tqdm( map( partial( - get_bam_data, + process_alignments, ref_lengths=ref_len_dict, percid=min_read_ani, min_read_length=min_read_length, @@ -736,7 +825,7 @@ def reassign_reads( tqdm.tqdm( p.imap_unordered( partial( - get_bam_data, + process_alignments, ref_lengths=ref_len_dict, percid=min_read_ani, min_read_length=min_read_length, @@ -884,12 +973,13 @@ def reassign_reads( log.info(f"::: Reassigning reads with {reassign_iters} iterations") else: log.info("::: Reassigning reads until convergence") - no_multimaps = resolve_multimaps( + no_multimaps = squarem_resolve_multimaps( init_data, iters=reassign_iters, scale=reassign_scale, mmap_dir=tmp_dir.name, max_memory=total_memory, + threads=threads, ) n_reads = len(list(set(no_multimaps["source"])))
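The chunked accumulation in `fast_e_step`/`fast_m_step` trades a single shared accumulator for per-chunk partial sums that are merged in a fixed order, which keeps the floating-point results reproducible regardless of how the Numba threads are scheduled. Below is a minimal pure-NumPy sketch of that idea (no Numba; sizes and array names are illustrative stand-ins for the real `data` columns):

```python
import numpy as np

rng = np.random.default_rng(42)
n_alns, n_subjects, n_chunks = 10_000, 50, 8    # illustrative sizes

prob = rng.random(n_alns)                       # stand-in for data["prob"]
subject = rng.integers(0, n_subjects, n_alns)   # stand-in for data["subject"]

# Per-chunk partial sums; in the Numba kernel each row is filled by a prange worker.
chunk_size = (n_alns + n_chunks - 1) // n_chunks
partial = np.zeros((n_chunks, n_subjects))
for c in range(n_chunks):
    lo, hi = c * chunk_size, min((c + 1) * chunk_size, n_alns)
    np.add.at(partial[c], subject[lo:hi], prob[lo:hi])

# Merging the chunks in a fixed order yields deterministic subject weights.
subject_weights = partial.sum(axis=0)

assert np.allclose(subject_weights,
                   np.bincount(subject, weights=prob, minlength=n_subjects))
```

Because each chunk is summed in array order and the chunks are merged in index order, the result is identical from run to run; letting parallel threads update one shared array directly would make the floating-point summation order (and hence the last digits of the weights) depend on thread scheduling.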