diff --git a/.github/workflows/CI.yaml b/.github/workflows/CI.yaml index ffc6296..eac3c2f 100644 --- a/.github/workflows/CI.yaml +++ b/.github/workflows/CI.yaml @@ -18,7 +18,7 @@ jobs: # Define a matrix of Python versions to test against strategy: matrix: - python-version: [3.11, 3.12] + python-version: [3.11, 3.12, 3.13] steps: # Step 1: Checkout the repository diff --git a/src/snipe/api/multisig_reference_QC.py b/src/snipe/api/multisig_reference_QC.py new file mode 100644 index 0000000..2edcb0f --- /dev/null +++ b/src/snipe/api/multisig_reference_QC.py @@ -0,0 +1,901 @@ +import heapq +import logging +import warnings +from typing import Any, Dict, Iterator, List, Optional, Union + +import numpy as np +from scipy.optimize import OptimizeWarning, curve_fit +from snipe.api.snipe_sig import SnipeSig +from snipe.api.enums import SigType +import os +import requests +from tqdm import tqdm +from urllib.parse import urlparse +from typing import Optional +import sourmash +from snipe.api.reference_QC import ReferenceQC +import concurrent + +# pylint disable C0301 + +class MultiSigReferenceQC: + r""" + Class for performing quality control of sequencing data against a reference genome. + + This class computes various metrics to assess the quality and characteristics of a sequencing sample, including coverage indices and abundance ratios, by comparing sample k-mer signatures with a reference genome and an optional amplicon signature. + + **Parameters** + + - `sample_sig` (`SnipeSig`): The sample k-mer signature (must be of type `SigType.SAMPLE`). + - `reference_sig` (`SnipeSig`): The reference genome k-mer signature (must be of type `SigType.GENOME`). + - `amplicon_sig` (`Optional[SnipeSig]`): The amplicon k-mer signature (must be of type `SigType.AMPLICON`), if applicable. + - `enable_logging` (`bool`): Flag to enable detailed logging. + + **Attributes** + + - `sample_sig` (`SnipeSig`): The sample signature. + - `reference_sig` (`SnipeSig`): The reference genome signature. 
+ - `amplicon_sig` (`Optional[SnipeSig]`): The amplicon signature. + - `sample_stats` (`Dict[str, Any]`): Statistics of the sample signature. + - `genome_stats` (`Dict[str, Any]`): Calculated genome-related statistics. + - `amplicon_stats` (`Dict[str, Any]`): Calculated amplicon-related statistics (if `amplicon_sig` is provided). + - `advanced_stats` (`Dict[str, Any]`): Calculated advanced statistics (optional). + - `predicted_assay_type` (`str`): Predicted assay type based on metrics. + + **Calculated Metrics** + + The class calculates the following metrics: + + - **Total unique k-mers** + - Description: Number of unique k-mers in the sample signature. + - Calculation: + $$ + \text{Total unique k-mers} = \left| \text{Sample k-mer set} \right| + $$ + + - **k-mer total abundance** + - Description: Sum of abundances of all k-mers in the sample signature. + - Calculation: + $$ + \text{k-mer total abundance} = \sum_{k \in \text{Sample k-mer set}} \text{abundance}(k) + $$ + + - **k-mer mean abundance** + - Description: Average abundance of k-mers in the sample signature. + - Calculation: + $$ + \text{k-mer mean abundance} = \frac{\text{k-mer total abundance}}{\text{Total unique k-mers}} + $$ + + - **k-mer median abundance** + - Description: Median abundance of k-mers in the sample signature. + - Calculation: Median of abundances in the sample k-mers. + + - **Number of singletons** + - Description: Number of k-mers with an abundance of 1 in the sample signature. + - Calculation: + $$ + \text{Number of singletons} = \left| \{ k \in \text{Sample k-mer set} \mid \text{abundance}(k) = 1 \} \right| + $$ + + - **Genomic unique k-mers** + - Description: Number of k-mers shared between the sample and the reference genome. 
+ - Calculation: + $$ + \text{Genomic unique k-mers} = \left| \text{Sample k-mer set} \cap \text{Reference genome k-mer set} \right| + $$ + + - **Genome coverage index** + - Description: Proportion of the reference genome's k-mers that are present in the sample. + - Calculation: + $$ + \text{Genome coverage index} = \frac{\text{Genomic unique k-mers}}{\left| \text{Reference genome k-mer set} \right|} + $$ + + - **Genomic k-mers total abundance** + - Description: Sum of abundances for k-mers shared with the reference genome. + - Calculation: + $$ + \text{Genomic k-mers total abundance} = \sum_{k \in \text{Sample k-mer set} \cap \text{Reference genome k-mer set}} \text{abundance}(k) + $$ + + - **Genomic k-mers mean abundance** + - Description: Average abundance of k-mers shared with the reference genome. + - Calculation: + $$ + \text{Genomic k-mers mean abundance} = \frac{\text{Genomic k-mers total abundance}}{\text{Genomic unique k-mers}} + $$ + + - **Mapping index** + - Description: Proportion of the sample's total k-mer abundance that maps to the reference genome. + - Calculation: + $$ + \text{Mapping index} = \frac{\text{Genomic k-mers total abundance}}{\text{k-mer total abundance}} + $$ + + If `amplicon_sig` is provided, additional metrics are calculated: + + - **Amplicon unique k-mers** + - Description: Number of k-mers shared between the sample and the amplicon. + - Calculation: + $$ + \text{Amplicon unique k-mers} = \left| \text{Sample k-mer set} \cap \text{Amplicon k-mer set} \right| + $$ + + - **Amplicon coverage index** + - Description: Proportion of the amplicon's k-mers that are present in the sample. + - Calculation: + $$ + \text{Amplicon coverage index} = \frac{\text{Amplicon unique k-mers}}{\left| \text{Amplicon k-mer set} \right|} + $$ + + - **Amplicon k-mers total abundance** + - Description: Sum of abundances for k-mers shared with the amplicon. 
+ - Calculation: + $$ + \text{Amplicon k-mers total abundance} = \sum_{k \in \text{Sample k-mer set} \cap \text{Amplicon k-mer set}} \text{abundance}(k) + $$ + + - **Amplicon k-mers mean abundance** + - Description: Average abundance of k-mers shared with the amplicon. + - Calculation: + $$ + \text{Amplicon k-mers mean abundance} = \frac{\text{Amplicon k-mers total abundance}}{\text{Amplicon unique k-mers}} + $$ + + - **Relative total abundance** + - Description: Ratio of the amplicon k-mers total abundance to the genomic k-mers total abundance. + - Calculation: + $$ + \text{Relative total abundance} = \frac{\text{Amplicon k-mers total abundance}}{\text{Genomic k-mers total abundance}} + $$ + + - **Relative coverage** + - Description: Ratio of the amplicon coverage index to the genome coverage index. + - Calculation: + $$ + \text{Relative coverage} = \frac{\text{Amplicon coverage index}}{\text{Genome coverage index}} + $$ + + - **Predicted Assay Type** + - Description: Predicted assay type based on the `Relative total abundance`. + - Calculation: + - If \(\text{Relative total abundance} \leq 0.0809\), then **WGS** (Whole Genome Sequencing). + - If \(\text{Relative total abundance} \geq 0.1188\), then **WXS** (Whole Exome Sequencing). + - If between these values, assign based on the closest threshold. + + **Advanced Metrics** (optional, calculated if `include_advanced` is `True`): + + - **Median-trimmed unique k-mers** + - Description: Number of unique k-mers in the sample after removing k-mers with abundance below the median. + - Calculation: + - Remove k-mers where \(\text{abundance}(k) < \text{Median abundance}\). + - Count the remaining k-mers. + + - **Median-trimmed total abundance** + - Description: Sum of abundances after median trimming. 
+ - Calculation: + $$ + \text{Median-trimmed total abundance} = \sum_{k \in \text{Median-trimmed Sample k-mer set}} \text{abundance}(k) + $$ + + - **Median-trimmed mean abundance** + - Description: Average abundance after median trimming. + - Calculation: + $$ + \text{Median-trimmed mean abundance} = \frac{\text{Median-trimmed total abundance}}{\text{Median-trimmed unique k-mers}} + $$ + + - **Median-trimmed median abundance** + - Description: Median abundance after median trimming. + - Calculation: Median of abundances in the median-trimmed sample. + + - **Median-trimmed Genomic unique k-mers** + - Description: Number of genomic k-mers in the median-trimmed sample. + - Calculation: + $$ + \text{Median-trimmed Genomic unique k-mers} = \left| \text{Median-trimmed Sample k-mer set} \cap \text{Reference genome k-mer set} \right| + $$ + + - **Median-trimmed Genome coverage index** + - Description: Genome coverage index after median trimming. + - Calculation: + $$ + \text{Median-trimmed Genome coverage index} = \frac{\text{Median-trimmed Genomic unique k-mers}}{\left| \text{Reference genome k-mer set} \right|} + $$ + + - **Median-trimmed Amplicon unique k-mers** (if `amplicon_sig` is provided) + - Description: Number of amplicon k-mers in the median-trimmed sample. + - Calculation: + $$ + \text{Median-trimmed Amplicon unique k-mers} = \left| \text{Median-trimmed Sample k-mer set} \cap \text{Amplicon k-mer set} \right| + $$ + + - **Median-trimmed Amplicon coverage index** + - Description: Amplicon coverage index after median trimming. + - Calculation: + $$ + \text{Median-trimmed Amplicon coverage index} = \frac{\text{Median-trimmed Amplicon unique k-mers}}{\left| \text{Amplicon k-mer set} \right|} + $$ + + - **Median-trimmed relative coverage** + - Description: Relative coverage after median trimming. 
def __init__(self, *,
             reference_sig: SnipeSig,
             amplicon_sig: Optional[SnipeSig] = None,
             ychr: Optional[SnipeSig] = None,
             chr_to_sig: Optional[Dict[str, SnipeSig]] = None,
             varsigs: Optional[List[SnipeSig]] = None,
             enable_logging: bool = False,
             **kwargs):
    r"""
    Store the reference-side signatures that every processed sample is compared against.

    **Parameters**

    - `reference_sig` (`SnipeSig`): Reference genome signature; its `sigtype` must be `SigType.GENOME`.
    - `amplicon_sig` (`Optional[SnipeSig]`): Amplicon signature; when given, its `sigtype` must be `SigType.AMPLICON`.
    - `ychr` (`Optional[SnipeSig]`): Y chromosome signature. It is only used when `chr_to_sig` is also given, in which case it is registered under the key `'sex-y'`.
    - `chr_to_sig` (`Optional[Dict[str, SnipeSig]]`): Chromosome name to signature mapping. Entries whose name ends with `-snipegenome` are excluded before `SnipeSig.get_unique_signatures` is applied. The caller's dictionary is not modified.
    - `varsigs` (`Optional[List[SnipeSig]]`): Variance signatures; each must have the same `ksize` and `scale` as `reference_sig`.
    - `enable_logging` (`bool`): When `True` the logger is set to `DEBUG` (and given a stream handler if it has none); otherwise it is set to `CRITICAL`.
    - `**kwargs`: Accepted for call-site compatibility; not used by this method.

    **Raises**

    - `ValueError`: If `reference_sig` or `amplicon_sig` has the wrong `sigtype`, if `ychr` is not a `SnipeSig`, or if any signature in `varsigs` differs from `reference_sig` in `ksize` or `scale`.
    """
    self.logger = logging.getLogger(self.__class__.__name__)

    # Configure the logger first so that every message below honours the requested level.
    if enable_logging:
        self.logger.setLevel(logging.DEBUG)
        if not self.logger.hasHandlers():
            handler = logging.StreamHandler()
            handler.setLevel(logging.DEBUG)
            handler.setFormatter(
                logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
            )
            self.logger.addHandler(handler)
        self.logger.debug("Logging is enabled for %s.", self.__class__.__name__)
    else:
        self.logger.setLevel(logging.CRITICAL)

    # Some debug messages below need a full merge of all chromosome signatures;
    # remember whether DEBUG is active so that work is skipped when nobody will read it.
    debug_enabled = self.logger.isEnabledFor(logging.DEBUG)

    # Initialize split cache
    self._split_cache: Dict[int, List[SnipeSig]] = {}
    self.logger.debug("Initialized split cache.")

    if debug_enabled:
        # Log the caller-supplied arguments explicitly rather than dumping locals(),
        # which would also include `self` and logging temporaries.
        passed_parameters = {
            "reference_sig": reference_sig,
            "amplicon_sig": amplicon_sig,
            "ychr": ychr,
            "chr_to_sig": chr_to_sig,
            "varsigs": varsigs,
            "enable_logging": enable_logging,
            "kwargs": kwargs,
        }
        self.logger.debug("passed parameters:\n")
        for key, value in passed_parameters.items():
            self.logger.debug("\t%s: %s", key, value)

    # ---- signature type validation ----
    if reference_sig.sigtype != SigType.GENOME:
        self.logger.error("Invalid signature type for reference_sig: %s", reference_sig.sigtype)
        raise ValueError(f"reference_sig must be of type {SigType.GENOME}, got {reference_sig.sigtype}")

    if amplicon_sig is not None and amplicon_sig.sigtype != SigType.AMPLICON:
        self.logger.error("Invalid signature type for amplicon_sig: %s", amplicon_sig.sigtype)
        raise ValueError(f"amplicon_sig must be of type {SigType.AMPLICON}, got {amplicon_sig.sigtype}")

    if ychr is not None and not isinstance(ychr, SnipeSig):
        # Report the Python type: a non-SnipeSig object cannot be assumed to expose `.sigtype`.
        ychr_type = type(ychr).__name__
        self.logger.error("Invalid type for ychr: %s", ychr_type)
        raise ValueError(f"ychr must be a SnipeSig instance, got {ychr_type}")

    # ---- chromosome-specific signatures ----
    self.specific_chr_to_sig: Optional[Dict[str, SnipeSig]] = None

    if chr_to_sig is not None:
        # Work on a shallow copy so adding 'sex-y' does not leak into the caller's dict.
        chr_sigs = dict(chr_to_sig)
        if ychr is not None:
            chr_sigs['sex-y'] = ychr

        self.logger.debug("Computing specific chromosome hashes for %s.", ','.join(chr_sigs.keys()))
        if debug_enabled:
            self.logger.debug(
                "\t-All hashes for chromosomes before getting unique sigs %d",
                len(SnipeSig.sum_signatures(list(chr_sigs.values()))),
            )
        self.specific_chr_to_sig = SnipeSig.get_unique_signatures({
            sig_name: sig
            for sig_name, sig in chr_sigs.items()
            if not sig_name.endswith("-snipegenome")
        })
        if debug_enabled:
            self.logger.debug(
                "\t-All hashes for chromosomes after getting unique sigs %d",
                len(SnipeSig.sum_signatures(list(self.specific_chr_to_sig.values()))),
            )
    elif ychr is not None:
        # Without chr_to_sig there is nothing to attach the Y signature to; say so instead of dropping it silently.
        self.logger.warning("ychr was provided without chr_to_sig; it will be ignored.")

    # ---- variance signatures ----
    self.variance_sigs: Optional[List[SnipeSig]] = None
    if varsigs is not None:
        self.logger.debug("Variance signatures provided.")
        # make sure they are same ksize and scale as reference_sig
        for sig in varsigs:
            if sig.ksize != reference_sig.ksize:
                self.logger.error("K-mer sizes do not match: varsigs.ksize=%d vs reference_sig.ksize=%d",
                                  sig.ksize, reference_sig.ksize)
                raise ValueError(f"varsigs ksize ({sig.ksize}) does not match reference_sig ksize ({reference_sig.ksize}).")
            if sig.scale != reference_sig.scale:
                self.logger.error("Scale values do not match: varsigs.scale=%d vs reference_sig.scale=%d",
                                  sig.scale, reference_sig.scale)
                raise ValueError(f"varsigs scale ({sig.scale}) does not match reference_sig scale ({reference_sig.scale}).")
        self.variance_sigs = varsigs

    # NOTE(review): this message and flag are set unconditionally, even when no
    # chromosome signatures were supplied -- confirm that this is intended.
    self.logger.debug("Chromosome specific signatures provided.")
    self.flag_activate_sex_metrics = True

    self.reference_sig = reference_sig
    self.amplicon_sig = amplicon_sig
    self.enable_logging = enable_logging
    # Filled elsewhere; starts empty here.
    self.sample_to_stats = {}

    # Set grey zone thresholds: [lower bound, upper bound] applied to
    # "Relative total abundance" when predicting WGS vs WXS.
    self.relative_total_abundance_grey_zone = [0.08092723407173719, 0.11884490500267662]
amplicon_sig.ksize=%d vs sample_sig.ksize=%d", + self.amplicon_sig.ksize, sample_sig.ksize) + raise ValueError(f"amplicon_sig ksize ({self.amplicon_sig.ksize}) does not match sample_sig ksize ({sample_sig.ksize}).") + if self.amplicon_sig.scale != sample_sig.scale: + self.logger.error("Scale values do not match: amplicon_sig.scale=%d vs sample_sig.scale=%d", + self.amplicon_sig.scale, sample_sig.scale) + raise ValueError(f"amplicon_sig scale ({self.amplicon_sig.scale}) does not match sample_sig scale ({sample_sig.scale}).") + + self.logger.debug("All signatures have matching ksize and scale.") + + # Verify signature types + if sample_sig._type != SigType.SAMPLE: + self.logger.error("Invalid signature type for sample_sig: %s | %s", sample_sig.sigtype, sample_sig._type) + raise ValueError(f"sample_sig must be of type {SigType.SAMPLE}, got {sample_sig.sigtype}") + + + # ============= SAMPLE STATS ============= + + self.logger.debug("Processing sample statistics.") + sample_stats_raw = sample_sig.get_sample_stats + sample_stats.update({ + "name": sample_stats_raw["name"], + "ksize": sample_stats_raw["ksize"], + "scale": sample_stats_raw["scale"], + "filename": sample_stats_raw["filename"], + "Total unique k-mers": sample_stats_raw["num_hashes"], + "k-mer total abundance": sample_stats_raw["total_abundance"], + "k-mer mean abundance": sample_stats_raw["mean_abundance"], + "k-mer median abundance": sample_stats_raw["median_abundance"], + "num_singletons": sample_stats_raw["num_singletons"], + }) + + # ============= GENOME STATS ============= + + self.logger.debug("Calculating genome statistics.") + # Compute intersection of sample and reference genome + self.logger.debug("Type of sample_sig: %s | Type of reference_sig: %s", sample_sig.sigtype, self.reference_sig.sigtype) + sample_genome = sample_sig & self.reference_sig + # Get stats (call get_sample_stats only once) + + # Log hashes and abundances for both sample and reference + # self.logger.debug("Sample hashes: %s", 
self.sample_sig.hashes) + # self.logger.debug("Sample abundances: %s", self.sample_sig.abundances) + # self.logger.debug("Reference hashes: %s", self.reference_sig.hashes) + # self.logger.debug("Reference abundances: %s", self.reference_sig.abundances) + + sample_genome_stats = sample_genome.get_sample_stats + + genome_stats.update({ + "Genomic unique k-mers": sample_genome_stats["num_hashes"], + "Genomic k-mers total abundance": sample_genome_stats["total_abundance"], + "Genomic k-mers mean abundance": sample_genome_stats["mean_abundance"], + "Genomic k-mers median abundance": sample_genome_stats["median_abundance"], + # Genome coverage index + "Genome coverage index": ( + sample_genome_stats["num_hashes"] / len(self.reference_sig) + if len(self.reference_sig) > 0 else 0 + ), + "Mapping index": ( + sample_genome_stats["total_abundance"] / sample_stats["k-mer total abundance"] + if sample_stats["k-mer total abundance"] > 0 else 0 + ), + }) + + # ============= AMPLICON STATS ============= + if self.amplicon_sig is not None: + self.logger.debug("Calculating amplicon statistics.") + sample_amplicon = sample_sig & self.amplicon_sig + sample_amplicon_stats = sample_amplicon.get_sample_stats + + amplicon_stats.update({ + "Amplicon unique k-mers": sample_amplicon_stats["num_hashes"], + "Amplicon k-mers total abundance": sample_amplicon_stats["total_abundance"], + "Amplicon k-mers mean abundance": sample_amplicon_stats["mean_abundance"], + "Amplicon k-mers median abundance": sample_amplicon_stats["median_abundance"], + "Amplicon coverage index": ( + sample_amplicon_stats["num_hashes"] / len(self.amplicon_sig) + if len(self.amplicon_sig) > 0 else 0 + ), + }) + + # ============= RELATIVE STATS ============= + amplicon_stats["Relative total abundance"] = ( + amplicon_stats["Amplicon k-mers total abundance"] / genome_stats["Genomic k-mers total abundance"] + if genome_stats["Genomic k-mers total abundance"] > 0 else 0 + ) + amplicon_stats["Relative coverage"] = ( + 
amplicon_stats["Amplicon coverage index"] / genome_stats["Genome coverage index"] + if genome_stats["Genome coverage index"] > 0 else 0 + ) + + relative_total_abundance = amplicon_stats["Relative total abundance"] + if relative_total_abundance <= self.relative_total_abundance_grey_zone[0]: + predicted_assay_type = "WGS" + elif relative_total_abundance >= self.relative_total_abundance_grey_zone[1]: + predicted_assay_type = "WXS" + else: + # Assign based on the closest threshold + distance_to_wgs = abs(relative_total_abundance - self.relative_total_abundance_grey_zone[0]) + distance_to_wxs = abs(relative_total_abundance - self.relative_total_abundance_grey_zone[1]) + predicted_assay_type = "WGS" if distance_to_wgs < distance_to_wxs else "WXS" + + self.logger.debug("Predicted assay type: %s", predicted_assay_type) + + else: + self.logger.debug("No amplicon signature provided.") + + # ============= Contaminatino/Error STATS ============= + + self.logger.debug("Calculuating error and contamination indices.") + try: + sample_nonref = sample_sig - self.reference_sig + sample_nonref_singletons = sample_nonref.count_singletons() + sample_nonref_non_singletons = sample_nonref.total_abundance - sample_nonref_singletons + sample_total_abundance = sample_sig.total_abundance + + predicted_error_index = sample_nonref_singletons / sample_total_abundance + predicted_contamination_index = sample_nonref_non_singletons / sample_total_abundance + + # predict error and contamination index + predicted_error_contamination_index["Predicted contamination index"] = predicted_contamination_index + predicted_error_contamination_index["Sequencing errors index"] = predicted_error_index + # except zero division error + except ZeroDivisionError: + self.logger.error("Please check the sample signature, it seems to be empty.") + + + # ============= Advanced Stats if needed ============= + if advanced: + # Copy sample signature to avoid modifying the original + median_trimmed_sample_sig = 
sample_sig.copy() + # Trim below median + median_trimmed_sample_sig.trim_below_median() + # Get stats + median_trimmed_sample_stats = median_trimmed_sample_sig.get_sample_stats + advanced_stats.update({ + "Median-trimmed unique k-mers": median_trimmed_sample_stats["num_hashes"], + "Median-trimmed total abundance": median_trimmed_sample_stats["total_abundance"], + "Median-trimmed mean abundance": median_trimmed_sample_stats["mean_abundance"], + "Median-trimmed median abundance": median_trimmed_sample_stats["median_abundance"], + }) + + # Genome stats for median-trimmed sample + median_trimmed_sample_genome = median_trimmed_sample_sig & self.reference_sig + median_trimmed_sample_genome_stats = median_trimmed_sample_genome.get_sample_stats + advanced_stats.update({ + "Median-trimmed Genomic unique k-mers": median_trimmed_sample_genome_stats["num_hashes"], + "Median-trimmed Genomic total abundance": median_trimmed_sample_genome_stats["total_abundance"], + "Median-trimmed Genomic mean abundance": median_trimmed_sample_genome_stats["mean_abundance"], + "Median-trimmed Genomic median abundance": median_trimmed_sample_genome_stats["median_abundance"], + "Median-trimmed Genome coverage index": ( + median_trimmed_sample_genome_stats["num_hashes"] / len(self.reference_sig) + if len(self.reference_sig) > 0 else 0 + ), + }) + + if self.amplicon_sig is not None: + self.logger.debug("Calculating advanced amplicon statistics.") + # Amplicon stats for median-trimmed sample + median_trimmed_sample_amplicon = median_trimmed_sample_sig & self.amplicon_sig + median_trimmed_sample_amplicon_stats = median_trimmed_sample_amplicon.get_sample_stats + advanced_stats.update({ + "Median-trimmed Amplicon unique k-mers": median_trimmed_sample_amplicon_stats["num_hashes"], + "Median-trimmed Amplicon total abundance": median_trimmed_sample_amplicon_stats["total_abundance"], + "Median-trimmed Amplicon mean abundance": median_trimmed_sample_amplicon_stats["mean_abundance"], + "Median-trimmed 
Amplicon median abundance": median_trimmed_sample_amplicon_stats["median_abundance"], + "Median-trimmed Amplicon coverage index": ( + median_trimmed_sample_amplicon_stats["num_hashes"] / len(self.amplicon_sig) + if len(self.amplicon_sig) > 0 else 0 + ), + }) + # Additional advanced relative metrics + self.logger.debug("Calculating advanced relative metrics.") + amplicon_stats["Median-trimmed relative coverage"] = ( + advanced_stats["Median-trimmed Amplicon coverage index"] / advanced_stats["Median-trimmed Genome coverage index"] + if advanced_stats["Median-trimmed Genome coverage index"] > 0 else 0 + ) + amplicon_stats["Median-trimmed relative mean abundance"] = ( + advanced_stats["Median-trimmed Amplicon mean abundance"] / advanced_stats["Median-trimmed Genomic mean abundance"] + if advanced_stats["Median-trimmed Genomic mean abundance"] > 0 else 0 + ) + # Update amplicon_stats with advanced metrics + amplicon_stats.update({ + "Median-trimmed relative coverage": amplicon_stats["Median-trimmed relative coverage"], + "Median-trimmed relative mean abundance": amplicon_stats["Median-trimmed relative mean abundance"], + }) + + advanced_stats.update(amplicon_stats) + + # ============= CHR STATS ============= + + if self.specific_chr_to_sig: + self.logger.debug("Calculating mean abundance for each chromosome.") + for chr_name, chr_sig in self.specific_chr_to_sig.items(): + self.logger.debug("Intersecting %s (%d) with %s (%d)", sample_sig.name, len(sample_sig), chr_name, len(chr_sig)) + chr_sample_sig = sample_sig & chr_sig + chr_stats = chr_sample_sig.get_sample_stats + chr_to_mean_abundance[chr_name] = chr_stats["mean_abundance"] + self.logger.debug("\t-Mean abundance for %s: %f", chr_name, chr_stats["mean_abundance"]) + + # chromosomes are numberd from 1 to ..., sort them by numer (might be string for sex chromosomes) then prefix them with chr- + def sort_chromosomes(chr_name): + try: + # Try to convert to integer for numeric chromosomes + return (0, int(chr_name)) + 
except ValueError: + # Non-numeric chromosomes (like 'x', 'y', 'z', etc.) + return (1, chr_name) + + + # Create a new dictionary with sorted chromosome names and prefixed with 'chr-' + sorted_chr_to_mean_abundance = { + f"chr-{chr_name.replace('sex-','').replace('autosome-','')}": chr_to_mean_abundance[chr_name] + for chr_name in sorted(chr_to_mean_abundance, key=sort_chromosomes) + } + + chrs_stats.update(sorted_chr_to_mean_abundance) + + # chr_to_mean_abundance but without any chr with partial name sex + autosomal_chr_to_mean_abundance = {} + for chr_name, mean_abundance in chr_to_mean_abundance.items(): + if "sex" in chr_name.lower() or "-snipegenome" in chr_name.lower(): + continue + autosomal_chr_to_mean_abundance[chr_name] = mean_abundance + + + # calculate the CV for the whole sample + if autosomal_chr_to_mean_abundance: + mean_abundances = np.array(list(autosomal_chr_to_mean_abundance.values()), dtype=float) + cv = np.std(mean_abundances) / np.mean(mean_abundances) if np.mean(mean_abundances) != 0 else 0.0 + chrs_stats.update({"Autosomal_CV": cv}) + self.logger.debug("Calculated Autosomal CV: %f", cv) + else: + self.logger.warning("No autosomal chromosomes were processed. 
'Autosomal_CV' set to None.") + chrs_stats.update({"Autosomal_CV": None}) + + # ============= SEX STATS ============= + + # Ensure that the chromosome X signature exists + + self.logger.debug("Length of genome before subtracting sex chromosomes %s", len(self.reference_sig)) + autosomals_genome_sig = self.reference_sig.copy() + for chr_name, chr_sig in self.specific_chr_to_sig.items(): + if "sex" in chr_name.lower(): + self.logger.debug("Removing %s chromosome from the autosomal genome signature.", chr_name) + self.logger.debug("Type of autosomals_genome_sig: %s | Type of chr_sig: %s", autosomals_genome_sig.sigtype, chr_sig.sigtype) + self.logger.debug("Length of autosomals_genome_sig: %s | Length of chr_sig: %s", len(autosomals_genome_sig), len(chr_sig)) + autosomals_genome_sig -= chr_sig + self.logger.debug("Length of genome after subtracting sex chromosomes %s", len(autosomals_genome_sig)) + + if 'sex-x' not in self.specific_chr_to_sig: + self.logger.warning("Chromosome X ('sex-x') not found in the provided signatures. X-Ploidy score will be set to zero.") + # set sex-x to an empty signature + self.specific_chr_to_sig['sex-x'] = SnipeSig.create_from_hashes_abundances( + hashes=np.array([], dtype=np.uint64), + abundances=np.array([], dtype=np.uint32), + ksize= self.specific_chr_to_sig[list( self.specific_chr_to_sig.keys())[0]].ksize, + scale= self.specific_chr_to_sig[list( self.specific_chr_to_sig.keys())[0]].scale, + ) + else: + self.logger.debug("X chromosome ('sex-x') detected.") + + + # Separate the autosomal genome signature from chromosome-specific signatures + + #! 
autosomal sig for now is the all of the genome minus sex chrs + self.logger.debug("Separating autosomal genome signature from chromosome-specific signatures.") + + + # Derive the X chromosome-specific signature by subtracting autosomal genome hashes + specific_xchr_sig = self.specific_chr_to_sig["sex-x"] - autosomals_genome_sig + self.logger.debug("\t-Derived X chromosome-specific signature size: %d hashes.", len(specific_xchr_sig)) + + # Intersect the sample signature with chromosome-specific signatures + sample_specific_xchr_sig = sample_sig & specific_xchr_sig + if len(sample_specific_xchr_sig) == 0: + self.logger.warning("No X chromosome-specific k-mers found in the sample signature.") + self.logger.debug("\t-Intersected sample signature with X chromosome-specific k-mers = %d hashes.", len(sample_specific_xchr_sig)) + sample_autosomal_sig = sample_sig & autosomals_genome_sig + self.logger.debug("\t-Intersected sample signature with autosomal genome k-mers = %d hashes.", len(sample_autosomal_sig)) + + # Retrieve mean abundances + xchr_mean_abundance = sample_specific_xchr_sig.get_sample_stats.get("mean_abundance", 0.0) + autosomal_mean_abundance = sample_autosomal_sig.get_sample_stats.get("mean_abundance", 0.0) + + # Calculate X-Ploidy score + if autosomal_mean_abundance == 0: + self.logger.warning("Autosomal mean abundance is zero. 
Setting X-Ploidy score to zero to avoid division by zero.") + xploidy_score = 0.0 + else: + xploidy_score = (xchr_mean_abundance / autosomal_mean_abundance) * \ + (len(autosomals_genome_sig) / len(specific_xchr_sig) if len(specific_xchr_sig) > 0 else 0.0) + + self.logger.debug("Calculated X-Ploidy score: %.4f", xploidy_score) + sex_stats.update({"X-Ploidy score": xploidy_score}) + + # Calculate Y-Coverage if Y chromosome is present + if 'sex-y' in self.specific_chr_to_sig and 'sex-x' in self.specific_chr_to_sig: + self.logger.debug("Calculating Y-Coverage based on Y chromosome-specific k-mers.") + + # Derive Y chromosome-specific k-mers by excluding autosomal and X chromosome k-mers + ychr_specific_kmers = self.specific_chr_to_sig["sex-y"] - autosomals_genome_sig - specific_xchr_sig + self.logger.debug("\t-Derived Y chromosome-specific signature size: %d hashes.", len(ychr_specific_kmers)) + + # Intersect Y chromosome-specific k-mers with the sample signature + ychr_in_sample = sample_sig & ychr_specific_kmers + self.logger.debug("\t-Intersected sample signature with Y chromosome-specific k-mers = %d hashes.", len(ychr_in_sample)) + if len(ychr_in_sample) == 0: + self.logger.warning("No Y chromosome-specific k-mers found in the sample signature.") + + # Derive autosomal-specific k-mers by excluding X and Y chromosome k-mers from the reference signature + autosomals_specific_kmers = self.reference_sig - self.specific_chr_to_sig["sex-x"] - self.specific_chr_to_sig['sex-y'] + + # Calculate Y-Coverage metric + if len(ychr_specific_kmers) == 0 or len(autosomals_specific_kmers) == 0: + self.logger.warning("Insufficient k-mers for Y-Coverage calculation. 
Setting Y-Coverage to zero.") + ycoverage = 0.0 + else: + ycoverage = (len(ychr_in_sample) / len(ychr_specific_kmers)) / \ + (len(sample_autosomal_sig) / len(autosomals_specific_kmers)) + + self.logger.debug("Calculated Y-Coverage: %.4f", ycoverage) + sex_stats.update({"Y-Coverage": ycoverage}) + else: + self.logger.warning("No Y chromosome-specific signature detected. Y-Coverage will be set to zero.") + + # ============= VARS NONREF STATS ============= + if self.variance_sigs: + self.logger.debug("Consuming non-reference k-mers from provided variables.") + self.logger.debug("\t-Current size of the sample signature: %d hashes.", len(sample_sig)) + + sample_nonref = sample_sig - self.reference_sig + + self.logger.debug("\t-Size of non-reference k-mers in the sample signature: %d hashes.", len(sample_nonref)) + sample_nonref.trim_singletons() + self.logger.debug("\t-Size of non-reference k-mers after trimming singletons: %d hashes.", len(sample_nonref)) + + sample_nonref_unique_hashes = len(sample_nonref) + + self.logger.debug("\t-Size of non-reference k-mers in the sample signature: %d hashes.", len(sample_nonref)) + if len(sample_nonref) == 0: + self.logger.warning("No non-reference k-mers found in the sample signature.") + return {} + + # intersect and report coverage and depth, then subtract from sample_nonref so sum will be 100% + for variance_sig in self.variance_sigs: + variance_name = variance_sig.name + sample_nonref_var: SnipeSig = sample_nonref & variance_sig + sample_nonref_var_total_abundance = sample_nonref_var.total_abundance + sample_nonref_var_unique_hashes = len(sample_nonref_var) + sample_nonref_var_coverage_index = sample_nonref_var_unique_hashes / sample_nonref_unique_hashes + vars_nonref_stats.update({ + f"{variance_name} non-genomic total k-mer abundance": sample_nonref_var_total_abundance, + f"{variance_name} non-genomic coverage index": sample_nonref_var_coverage_index + }) + + self.logger.debug("\t-Consuming non-reference k-mers from 
variable '%s'.", variance_name) + sample_nonref -= sample_nonref_var + self.logger.debug("\t-Size of remaining non-reference k-mers in the sample signature: %d hashes.", len(sample_nonref)) + + vars_nonref_stats["non-var non-genomic total k-mer abundance"] = sample_nonref.total_abundance + vars_nonref_stats["non-var non-genomic coverage index"] = len(sample_nonref) / sample_nonref_unique_hashes if sample_nonref_unique_hashes > 0 else 0.0 + + self.logger.debug( + "After consuming all vars from the non reference k-mers, the size of the sample signature is: %d hashes, " + "with total abundance of %s.", + len(sample_nonref), sample_nonref.total_abundance + ) + + + # ============= Coverage Prediction (ROI) ============= + + if predict_extra_folds: + predicted_fold_coverage = {} + predicted_fold_delta_coverage = {} + nparts = 30 + if isinstance(self.amplicon_sig, SnipeSig): + roi_reference_sig = self.amplicon_sig + self.logger.debug("Using amplicon signature as ROI reference.") + else: + roi_reference_sig = self.reference_sig + self.logger.debug("Using reference genome signature as ROI reference.") + + # Get sample signature intersected with the reference + _sample_sig_genome = sample_sig & self.reference_sig + hashes = _sample_sig_genome.hashes + abundances = _sample_sig_genome.abundances + N = len(hashes) + + # Generate random fractions using Dirichlet distribution + fractions = np.random.dirichlet([1] * nparts, size=N) # Shape: (N, nparts) + + # Calculate counts for each part + counts = np.round(abundances[:, None] * fractions).astype(int) # Shape: (N, nparts) + + # Adjust counts to ensure sums match original abundances + differences = abundances - counts.sum(axis=1) + indices = np.argmax(counts, axis=1) + counts[np.arange(N), indices] += differences + + # Compute cumulative counts + counts_cumulative = counts.cumsum(axis=1) # Shape: (N, nparts) + cumulative_total_abundances = counts.sum(axis=0).cumsum() + + coverage_depth_data = [] + + # Force conversion to GENOME + 
roi_reference_sig.sigtype = SigType.GENOME + + for i in range(nparts): + cumulative_counts = counts_cumulative[:, i] + idx = cumulative_counts > 0 + + cumulative_hashes = hashes[idx] + cumulative_abundances = cumulative_counts[idx] + + cumulative_snipe_sig = SnipeSig.create_from_hashes_abundances( + hashes=cumulative_hashes, + abundances=cumulative_abundances, + ksize=sample_sig.ksize, + scale=sample_sig.scale, + name=f"{sample_sig.name}_cumulative_part_{i+1}", + filename=sample_sig.filename, + enable_logging=self.enable_logging + ) + + # Compute coverage index + cumulative_qc = ReferenceQC( + sample_sig=cumulative_snipe_sig, + reference_sig=roi_reference_sig, + enable_logging=self.enable_logging + ) + cumulative_stats = cumulative_qc.get_aggregated_stats() + cumulative_coverage_index = cumulative_stats.get("Genome coverage index", 0.0) + cumulative_total_abundance = cumulative_total_abundances[i] + + coverage_depth_data.append({ + "cumulative_parts": i + 1, + "cumulative_total_abundance": cumulative_total_abundance, + "cumulative_coverage_index": cumulative_coverage_index, + }) + + self.logger.debug("Added coverage depth data for cumulative part %d.", i + 1) + + self.logger.debug("Coverage vs depth calculation completed.") + + for extra_fold in predict_extra_folds: + if extra_fold < 1: + self.warning.error("Extra fold must be >= 1. 
Skipping this extra fold prediction.") + continue + + # Extract cumulative total abundance and coverage index + x_data = np.array([d["cumulative_total_abundance"] for d in coverage_depth_data]) + y_data = np.array([d["cumulative_coverage_index"] for d in coverage_depth_data]) + + # Saturation model function + def saturation_model(x, a, b): + return a * x / (b + x) + + # Initial parameter guesses + initial_guess = [y_data[-1], x_data[int(len(x_data) / 2)]] + + # Fit the model to the data + try: + with warnings.catch_warnings(): + warnings.simplefilter("error", OptimizeWarning) + params, covariance = curve_fit( + saturation_model, + x_data, + y_data, + p0=initial_guess, + bounds=(0, np.inf), + maxfev=10000 + ) + except (RuntimeError, OptimizeWarning) as exc: + self.logger.error("Curve fitting failed.") + raise RuntimeError("Saturation model fitting failed. Cannot predict coverage.") from exc + + # Check if covariance contains inf or nan + if np.isinf(covariance).any() or np.isnan(covariance).any(): + self.logger.error("Covariance of parameters could not be estimated.") + raise RuntimeError("Saturation model fitting failed. 
Cannot predict coverage.") + + a, b = params + + # Predict coverage at increased sequencing depth + total_abundance = x_data[-1] + predicted_total_abundance = total_abundance * (1 + extra_fold) + predicted_coverage = saturation_model(predicted_total_abundance, a, b) + + # Ensure the predicted coverage does not exceed maximum possible coverage + max_coverage = 1.0 # Coverage index cannot exceed 1 + predicted_coverage = min(predicted_coverage, max_coverage) + predicted_fold_coverage[f"Predicted coverage with {extra_fold} extra folds"] = predicted_coverage + _delta_coverage = predicted_coverage - y_data[-1] + predicted_fold_delta_coverage[f"Predicted delta coverage with {extra_fold} extra folds"] = _delta_coverage + if _delta_coverage < 0: + self.logger.warning( + "Predicted coverage at %.2f-fold increase is less than the current coverage (probably low complexity).", + extra_fold + ) + self.logger.debug("Predicted coverage at %.2f-fold increase: %f", extra_fold, predicted_coverage) + self.logger.debug("Predicted delta coverage at %.2f-fold increase: %f", extra_fold, _delta_coverage) + + # Update the ROI stats + roi_stats.update(predicted_fold_coverage) + roi_stats.update(predicted_fold_delta_coverage) + + # ============= Merging all stats in one dictionary ============= + aggregated_stats = {} + if sample_stats: + aggregated_stats.update(sample_stats) + if genome_stats: + aggregated_stats.update(genome_stats) + if amplicon_stats: + aggregated_stats.update(amplicon_stats) + if advanced_stats: + aggregated_stats.update(advanced_stats) + if chrs_stats: + aggregated_stats.update(chrs_stats) + if sex_stats: + aggregated_stats.update(sex_stats) + if predicted_error_contamination_index: + aggregated_stats.update(predicted_error_contamination_index) + if vars_nonref_stats: + aggregated_stats.update(vars_nonref_stats) + if roi_stats: + aggregated_stats.update(roi_stats) + + # update the class with the new sample + self.sample_to_stats[sample_sig.name] = aggregated_stats + + 
return aggregated_stats diff --git a/src/snipe/api/reference_QC.py b/src/snipe/api/reference_QC.py index b6179b9..1f0b2c0 100644 --- a/src/snipe/api/reference_QC.py +++ b/src/snipe/api/reference_QC.py @@ -696,7 +696,8 @@ def split_sig_randomly(self, n: int) -> List[SnipeSig]: self.logger.debug("No cached splits found for n=%d. Proceeding to split.", n) # Get k-mers and abundances - hash_to_abund = dict(zip(self.sample_sig.hashes, self.sample_sig.abundances)) + _sample_genome = self.sample_sig & self.reference_sig + hash_to_abund = dict(zip(_sample_genome.hashes, _sample_genome.abundances)) random_split_sigs = self.distribute_kmers_random(hash_to_abund, n) split_sigs = [ SnipeSig.create_from_hashes_abundances( @@ -1555,171 +1556,3 @@ def load_genome_sig_to_dict(self, *, zip_file_path: str, **kwargs) -> Dict[str, return genome_chr_name_to_sig - - -class PreparedQC(ReferenceQC): - r""" - Class for quality control (QC) analysis of sample signature against prepared snipe profiles. - """ - - def __init__(self, *, sample_sig: SnipeSig, snipe_db_path: str = '~/.snipe/dbs/', ref_id: Optional[str] = None, amplicon_id: Optional[str] = None, enable_logging: bool = False, **kwargs): - """ - Initialize the PreparedQC instance. - - **Parameters** - - - `sample_sig` (`SnipeSig`): The sample k-mer signature. - - `snipe_db_path` (`str`): Path to the local Snipe database directory. - - `ref_id` (`Optional[str]`): Reference identifier for selecting specific profiles. - - `enable_logging` (`bool`): Flag to enable detailed logging. - - `**kwargs`: Additional keyword arguments. 
- """ - self.snipe_db_path = os.path.expanduser(snipe_db_path) - self.ref_id = ref_id - - # Ensure the local database directory exists - os.makedirs(self.snipe_db_path, exist_ok=True) - if enable_logging: - self.logger.debug(f"Local Snipe DB path set to: {self.snipe_db_path}") - else: - self.logger.debug("Logging is disabled for PreparedQC.") - - # Initialize without a reference signature for now; it can be set after downloading - super().__init__( - sample_sig=sample_sig, - reference_sig=None, # To be set after downloading - enable_logging=enable_logging, - **kwargs - ) - - def download_osf_db(self, url: str, save_path: str = '~/.snipe/dbs', force: bool = False) -> Optional[str]: - """ - Download a file from OSF using the provided URL. The file is saved with its original name - as specified by the OSF server via the Content-Disposition header. - - **Parameters** - - - `url` (`str`): The OSF URL to download the file from. - - `save_path` (`str`): The directory path where the file will be saved. Supports user (~) and environment variables. - Default is the local Snipe database directory. - - `force` (`bool`): If True, overwrite the file if it already exists. Default is False. - - **Returns** - - - `Optional[str]`: The path to the downloaded file if successful, else None. - - **Raises** - - - `requests.exceptions.RequestException`: If an error occurs during the HTTP request. - - `Exception`: For any other exceptions that may arise. 
- """ - try: - # Expand user (~) and environment variables in save_path - expanded_save_path = os.path.expanduser(os.path.expandvars(save_path)) - self.logger.debug(f"Expanded save path: {expanded_save_path}") - - # Ensure the download URL ends with '/download' - parsed_url = urlparse(url) - if not parsed_url.path.endswith('/download'): - download_url = f"{url.rstrip('/')}/download" - else: - download_url = url - - self.logger.debug(f"Download URL: {download_url}") - - # Ensure the save directory exists - os.makedirs(expanded_save_path, exist_ok=True) - self.logger.debug(f"Save path verified/created: {expanded_save_path}") - - # Initiate the GET request with streaming - with requests.get(download_url, stream=True, allow_redirects=True) as response: - response.raise_for_status() # Raise an exception for HTTP errors - - # Attempt to extract filename from Content-Disposition - content_disposition = response.headers.get('Content-Disposition') - filename = self._extract_filename(content_disposition, parsed_url.path) - self.logger.debug(f"Filename determined: {filename}") - - # Define the full save path - full_save_path = os.path.join(expanded_save_path, filename) - self.logger.debug(f"Full save path: {full_save_path}") - - # Check if the file already exists - if os.path.exists(full_save_path): - if force: - self.logger.info(f"Overwriting existing file: {full_save_path}") - else: - self.logger.info(f"File already exists: {full_save_path}. 
Skipping download.") - return full_save_path - - # Get the total file size for the progress bar - total_size = int(response.headers.get('Content-Length', 0)) - - # Initialize the progress bar - with open(full_save_path, 'wb') as file, tqdm( - total=total_size, - unit='B', - unit_scale=True, - unit_divisor=1024, - desc=filename, - ncols=100 - ) as bar: - for chunk in response.iter_content(chunk_size=1024): - if chunk: # Filter out keep-alive chunks - file.write(chunk) - bar.update(len(chunk)) - - self.logger.info(f"File downloaded successfully: {full_save_path}") - return full_save_path - - except requests.exceptions.RequestException as req_err: - self.logger.error(f"Request error occurred while downloading {url}: {req_err}") - raise - except Exception as e: - self.logger.error(f"An unexpected error occurred while downloading {url}: {e}") - raise - - def _extract_filename(self, content_disposition: Optional[str], url_path: str) -> str: - """ - Extract filename from Content-Disposition header or fallback to URL path. - - **Parameters** - - - `content_disposition` (`Optional[str]`): The Content-Disposition header value. - - `url_path` (`str`): The path component of the URL. - - **Returns** - - - `str`: The extracted filename. 
- """ - filename = None - if content_disposition: - self.logger.debug("Parsing Content-Disposition header for filename.") - parts = content_disposition.split(';') - for part in parts: - part = part.strip() - if part.lower().startswith('filename*='): - # Handle RFC 5987 encoding (e.g., filename*=UTF-8''example.txt) - encoded_filename = part.split('=', 1)[1].strip() - if "''" in encoded_filename: - filename = encoded_filename.split("''", 1)[1] - else: - filename = encoded_filename - self.logger.debug(f"Filename extracted from headers (RFC 5987): {filename}") - break - elif part.lower().startswith('filename='): - # Remove 'filename=' and any surrounding quotes - filename = part.split('=', 1)[1].strip(' "') - self.logger.debug(f"Filename extracted from headers: {filename}") - break - - if not filename: - self.logger.debug("Falling back to filename derived from URL path.") - filename = os.path.basename(url_path) - if not filename: - filename = 'downloaded_file' - self.logger.debug(f"Filename derived from URL: {filename}") - - return filename - - diff --git a/src/snipe/api/snipe_sig.py b/src/snipe/api/snipe_sig.py index df5633b..08b00d0 100644 --- a/src/snipe/api/snipe_sig.py +++ b/src/snipe/api/snipe_sig.py @@ -1,9 +1,8 @@ import heapq import logging -import sourmash.save_load from snipe.api.enums import SigType -from typing import Any, Dict, Iterator, List, Optional, Union +from typing import Dict, Iterator, List, Union, Optional import numpy as np import sourmash import os @@ -18,8 +17,9 @@ class SnipeSig: such as customized set operations and abundance management. 
""" - def __init__(self, *, sourmash_sig: Union[str, sourmash.signature.SourmashSignature, sourmash.signature.FrozenSourmashSignature], - ksize: int = 51, scale: int = 10000, sig_type=SigType.SAMPLE, enable_logging: bool = False, **kwargs): + def __init__(self, *, + sourmash_sig: Union[str, sourmash.signature.SourmashSignature, sourmash.signature.FrozenSourmashSignature], + sig_type=SigType.SAMPLE, enable_logging: bool = False, **kwargs): r""" Initialize the SnipeSig with a sourmash signature object or a path to a signature. @@ -54,15 +54,15 @@ def __init__(self, *, sourmash_sig: Union[str, sourmash.signature.SourmashSignat # Initialize internal variables self.logger.debug("Initializing SnipeSig with sourmash_sig: %s", sourmash_sig) - self._scale = scale - self._ksize = ksize - self._md5sum = None + self._scale: int = None + self._ksize: int = None + self._md5sum: str = None self._hashes = np.array([], dtype=np.uint64) self._abundances = np.array([], dtype=np.uint32) - self._type = sig_type - self._name = None - self._filename = None - self._track_abundance = False + self._type: SigType = sig_type + self._name: str = None + self._filename: str = None + self._track_abundance: bool = True sourmash_sigs: Dict[str, sourmash.signature.SourmashSignature] = {} _sourmash_sig: Union[sourmash.signature.SourmashSignature, sourmash.signature.FrozenSourmashSignature] = None @@ -117,18 +117,19 @@ def __init__(self, *, sourmash_sig: Union[str, sourmash.signature.SourmashSignat self.logger.debug(f"Iterating over signature: {signame}") if signame.endswith("-snipegenome"): sig = sig.to_mutable() + # self.chr_to_sig[sig.name] = SnipeSig(sourmash_sig=sig, sig_type=SigType.GENOME, enable_logging=enable_logging) sig.name = sig.name.replace("-snipegenome", "") self.logger.debug("Found a genome signature with the snipe suffix `-snipegenome`. 
Restoring original name `%s`.", sig.name) _sourmash_sig = sig elif signame.startswith("sex-"): self.logger.debug("Found a sex chr signature %s", signame) sig = sig.to_mutable() - sig.name = signame.replace("sex-","") + # sig.name = signame.replace("sex-","") self.chr_to_sig[sig.name] = SnipeSig(sourmash_sig=sig, sig_type=SigType.AMPLICON, enable_logging=enable_logging) elif signame.startswith("autosome-"): self.logger.debug("Found an autosome signature %s", signame) sig = sig.to_mutable() - sig.name = signame.replace("autosome-","") + # sig.name = signame.replace("autosome-","") self.chr_to_sig[sig.name] = SnipeSig(sourmash_sig=sig, sig_type=SigType.AMPLICON, enable_logging=enable_logging) else: continue @@ -281,6 +282,13 @@ def sigtype(self, sigtype: SigType): Set the type of the signature. """ self._type = sigtype + + @track_abundance.setter + def track_abundance(self, track_abundance: bool): + r""" + Set whether the signature tracks abundance. + """ + self._track_abundance = track_abundance def get_info(self) -> dict: r""" @@ -490,7 +498,10 @@ def _convert_to_sourmash_signature(self): self.logger.debug("Converting SnipeSig to sourmash.signature.SourmashSignature.") mh = sourmash.minhash.MinHash(n=0, ksize=self._ksize, scaled=self._scale, track_abundance=self._track_abundance) - mh.set_abundances(dict(zip(self._hashes, self._abundances))) + if self._track_abundance: + mh.set_abundances(dict(zip(self._hashes, self._abundances))) + else: + mh.add_many(self._hashes) self.sourmash_sig = sourmash.signature.SourmashSignature(mh, name=self._name, filename=self._filename) self.logger.debug("Conversion to sourmash.signature.SourmashSignature completed.") @@ -516,7 +527,7 @@ def export(self, path, force=False) -> None: with sourmash.save_load.SaveSignatures_ZipFile(path) as save_sigs: save_sigs.add(self.sourmash_sig) except Exception as e: - logging.error("Failed to export signatures to zip: %s", e) + self.logger.error("Failed to export signatures to zip: %s", e) raise 
Exception(f"Failed to export signatures to zip: {e}") from e else: raise ValueError("Output file must be either a .sig or .zip file.") @@ -1273,6 +1284,7 @@ def reset_abundance(self, new_abundance: int = 1): self._validate_abundance_operation(new_abundance, "reset abundance") self._abundances[:] = new_abundance + self.track_abundance = True self.logger.debug("Reset all abundances to %d.", new_abundance) def keep_min_abundance(self, min_abundance: int): diff --git a/src/snipe/cli/cli_qc.py b/src/snipe/cli/cli_qc.py index 77852ce..3136d78 100644 --- a/src/snipe/cli/cli_qc.py +++ b/src/snipe/cli/cli_qc.py @@ -2,7 +2,7 @@ import sys import time import logging -from typing import Optional, Any, List, Dict, Set +from typing import Optional, Any, List, Dict, Set, Union import click import pandas as pd @@ -10,125 +10,20 @@ import concurrent.futures from snipe.api.enums import SigType -from snipe.api.sketch import SnipeSketch from snipe.api.snipe_sig import SnipeSig -from snipe.api.reference_QC import ReferenceQC +from snipe.api.multisig_reference_QC import MultiSigReferenceQC -import concurrent.futures -import signal - -def process_sample(sample_path: str, ref_path: str, amplicon_path: Optional[str], - advanced: bool, roi: bool, debug: bool, - ychr: Optional[str] = None, - vars_paths: Optional[List[str]] = None) -> Dict[str, Any]: - """ - Process a single sample for QC. - - Parameters: - - sample_path (str): Path to the sample signature file. - - ref_path (str): Path to the reference signature file. - - amplicon_path (Optional[str]): Path to the amplicon signature file. - - advanced (bool): Flag to include advanced metrics. - - roi (bool): Flag to calculate ROI. - - debug (bool): Flag to enable debugging. - - vars_paths (Optional[List[str]]): List of variable signature file paths. - - Returns: - - Dict[str, Any]: QC results for the sample. 
- """ - # Configure worker-specific logging - logger = logging.getLogger(f'snipe_qc_worker_{os.path.basename(sample_path)}') - logger.setLevel(logging.DEBUG if debug else logging.INFO) - if not logger.hasHandlers(): - handler = logging.StreamHandler(sys.stdout) - handler.setLevel(logging.DEBUG if debug else logging.INFO) - formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') - handler.setFormatter(formatter) - logger.addHandler(handler) - sample_name = os.path.splitext(os.path.basename(sample_path))[0] - try: - # Load sample signature - sample_sig = SnipeSig(sourmash_sig=sample_path, sig_type=SigType.SAMPLE, enable_logging=debug) - logger.debug(f"Loaded sample signature: {sample_sig.name}") - - # Load reference signature - reference_sig = SnipeSig(sourmash_sig=ref_path, sig_type=SigType.GENOME, enable_logging=debug) - logger.debug(f"Loaded reference signature: {reference_sig.name}") - - # Load amplicon signature if provided - amplicon_sig = None - if amplicon_path: - amplicon_sig = SnipeSig(sourmash_sig=amplicon_path, sig_type=SigType.AMPLICON, enable_logging=debug) - logger.debug(f"Loaded amplicon signature: {amplicon_sig.name}") +def validate_sig_input(ctx, param, value: tuple) -> str: + supported_extensions = ['.zip', '.sig'] + for path in value: + if not os.path.exists(path): + raise click.BadParameter(f"File not found: {path}") + if not any(path.lower().endswith(ext) for ext in supported_extensions): + raise click.BadParameter(f"Unsupported file format: {path}, supported formats: {', '.join(supported_extensions)}") + return value - # Instantiate ReferenceQC - qc_instance = ReferenceQC( - sample_sig=sample_sig, - reference_sig=reference_sig, - amplicon_sig=amplicon_sig, - enable_logging=debug - ) - - # Load variable signatures if provided - if vars_paths: - qc_instance.logger.debug(f"Loading {len(vars_paths)} variable signature(s).") - vars_dict: Dict[str, SnipeSig] = {} - vars_order: List[str] = [] - for path in vars_paths: - 
qc_instance.logger.debug(f"Loading variable signature from: {path}") - var_sig = SnipeSig(sourmash_sig=path, sig_type=SigType.AMPLICON, enable_logging=debug) - var_name = var_sig.name if var_sig.name else os.path.basename(path) - qc_instance.logger.debug(f"Loaded variable signature: {var_name}") - vars_dict[var_name] = var_sig - vars_order.append(var_name) - logger.debug(f"Loaded variable signature '{var_name}': {var_sig.name}") - # Pass variables to ReferenceQC - qc_instance.nonref_consume_from_vars(vars=vars_dict, vars_order=vars_order) - # No else block needed; variables are optional - - # Calculate chromosome metrics - # genome_chr_to_sig: Dict[str, SnipeSig] = qc_instance.load_genome_sig_to_dict(zip_file_path = ref_path) - chr_to_sig = reference_sig.chr_to_sig.copy() - if ychr: - ychr_sig = SnipeSig(sourmash_sig=ychr, sig_type=SigType.GENOME, enable_logging=debug) - chr_to_sig['y'] = ychr_sig - - qc_instance.calculate_chromosome_metrics(chr_to_sig) - - # Get aggregated stats - aggregated_stats = qc_instance.get_aggregated_stats(include_advanced=advanced) - - # Initialize result dict - result = { - "sample": sample_name, - "file_path": os.path.abspath(sample_path), - } - # Add aggregated stats - result.update(aggregated_stats) - - # Calculate ROI if requested - if roi: - logger.debug(f"Calculating ROI for sample: {sample_name}") - for fold in [1, 2, 5, 9]: - try: - predicted_coverage = qc_instance.predict_coverage(extra_fold=fold) - result[f"Predicted_Coverage_Fold_{fold}x"] = predicted_coverage - logger.debug(f"Fold {fold}x: Predicted Coverage = {predicted_coverage}") - except RuntimeError as e: - logger.error(f"ROI calculation failed for sample {sample_name} at fold {fold}x: {e}") - result[f"Predicted_Coverage_Fold_{fold}x"] = None - - return result - except Exception as e: - logger.error(f"QC failed for sample {sample_path}: {e}") - return { - "sample": sample_name, - "file_path": os.path.abspath(sample_path), - "QC_Error": str(e) - } - def 
validate_tsv_file(ctx, param, value: str) -> str: if not value.lower().endswith('.tsv'): raise click.BadParameter('Output file must have a .tsv extension.') @@ -137,18 +32,17 @@ def validate_tsv_file(ctx, param, value: str) -> str: @click.command() @click.option('--ref', type=click.Path(exists=True), required=True, help='Reference genome signature file (required).') -@click.option('--sample', type=click.Path(exists=True), multiple=True, help='Sample signature file. Can be provided multiple times.') +@click.option('--sample', type=click.Path(exists=True), callback=validate_sig_input, multiple=True, default = None, help='Sample signature file. Can be provided multiple times.') @click.option('--samples-from-file', type=click.Path(exists=True), help='File containing sample paths (one per line).') @click.option('--amplicon', type=click.Path(exists=True), help='Amplicon signature file (optional).') @click.option('--roi', is_flag=True, default=False, help='Calculate ROI for 1,2,5,9 folds.') -@click.option('--cores', '-c', default=4, type=int, show_default=True, help='Number of CPU cores to use for parallel processing.') @click.option('--advanced', is_flag=True, default=False, help='Include advanced QC metrics.') @click.option('--ychr', type=click.Path(exists=True), help='Y chromosome signature file (overrides the reference ychr).') @click.option('--debug', is_flag=True, default=False, help='Enable debugging and detailed logging.') @click.option('-o', '--output', required=True, callback=validate_tsv_file, help='Output TSV file for QC results.') @click.option('--var', 'vars', multiple=True, type=click.Path(exists=True), help='Variable signature file path. 
Can be used multiple times.') def qc(ref: str, sample: List[str], samples_from_file: Optional[str], - amplicon: Optional[str], roi: bool, cores: int, advanced: bool, + amplicon: Optional[str], roi: bool, advanced: bool, ychr: Optional[str], debug: bool, output: str, vars: List[str]): """ Perform quality control (QC) on multiple samples against a reference genome. @@ -178,9 +72,6 @@ def qc(ref: str, sample: List[str], samples_from_file: Optional[str], - `--roi` Calculate ROI for 1x, 2x, 5x, and 9x coverage folds. - - `--cores, -c INTEGER` **[default: 4]** - Number of CPU cores to use for parallel processing. - - `--advanced` Include advanced QC metrics. @@ -245,7 +136,7 @@ def qc(ref: str, sample: List[str], samples_from_file: Optional[str], ### Combining Multiple Options ```bash - snipe qc --ref reference.sig --sample sample1.sig --sample sample2.sig --amplicon amplicon.sig --var var1.sig --var var2.sig --advanced --roi --cores 8 -o qc_results.tsv + snipe qc --ref reference.sig --sample sample1.sig --sample sample2.sig --amplicon amplicon.sig --var var1.sig --var var2.sig --advanced --roi -o qc_results.tsv ``` ## Detailed Use Cases @@ -384,12 +275,10 @@ def qc(ref: str, sample: List[str], samples_from_file: Optional[str], ### Use Case 7: Combining Multiple Options for Comprehensive QC - **Objective:** Perform a comprehensive QC that includes multiple samples, amplicon signature, variable signatures, advanced metrics, and ROI calculations using multiple CPU cores for efficiency. 
- **Command:** ```bash - snipe qc --ref reference.sig --sample sample1.sig --sample sample2.sig --amplicon amplicon.sig --var var1.sig --var var2.sig --advanced --roi --cores 8 -o qc_comprehensive.tsv + snipe qc --ref reference.sig --sample sample1.sig --sample sample2.sig --amplicon amplicon.sig --var var1.sig --var var2.sig --advanced --roi -o qc_comprehensive.tsv ``` **Explanation:** @@ -400,7 +289,6 @@ def qc(ref: str, sample: List[str], samples_from_file: Optional[str], - `--var var1.zip` & `--var var2.zip`: Variable signature files. - `--advanced`: Includes advanced QC metrics. - `--roi`: Enables ROI calculations. - - `--cores 8`: Utilizes 8 CPU cores for parallel processing. - `-o qc_comprehensive.tsv`: Output file for QC results. **Expected Output:** @@ -408,6 +296,8 @@ def qc(ref: str, sample: List[str], samples_from_file: Optional[str], A TSV file named `qc_comprehensive.tsv` containing comprehensive QC metrics, including advanced analyses, ROI predictions, and data from amplicon and variable signatures for both `sample1.sig` and `sample2.sig`. 
""" + print(sample) + start_time = time.time() # Configure logging @@ -423,7 +313,12 @@ def qc(ref: str, sample: List[str], samples_from_file: Optional[str], logger.info("Starting QC process.") # Collect sample paths from --sample and --samples-from-file - samples_set: Set[str] = set(sample) # Start with samples provided via --sample + samples_set: Set[str] = set() + if sample: + for _sample in sample: + logger.debug(f"Adding sample from command-line: {_sample}") + samples_set.add(_sample) + if samples_from_file: logger.debug(f"Reading samples from file: {samples_from_file}") @@ -469,9 +364,21 @@ def qc(ref: str, sample: List[str], samples_from_file: Optional[str], except Exception as e: logger.error(f"Failed to load amplicon signature from {amplicon}: {e}") sys.exit(1) + + # Load Y chromosome signature if provided + ychr_sig = None + if ychr: + logger.info(f"Loading Y chromosome signature from: {ychr}") + try: + ychr_sig = SnipeSig(sourmash_sig=ychr, sig_type=SigType.GENOME, enable_logging=debug) + logger.debug(f"Loaded Y chromosome signature: {ychr_sig.name}") + except Exception as e: + logger.error(f"Failed to load Y chromosome signature from {ychr}: {e}") + sys.exit(1) # Prepare variable signatures if provided vars_paths = [] + vars_snipesigs = [] if vars: logger.info(f"Loading {len(vars)} variable signature(s).") for path in vars: @@ -479,73 +386,59 @@ def qc(ref: str, sample: List[str], samples_from_file: Optional[str], logger.error(f"Variable signature file does not exist: {path}") sys.exit(1) vars_paths.append(os.path.abspath(path)) + try: + var_sig = SnipeSig(sourmash_sig=path, sig_type=SigType.AMPLICON, enable_logging=debug) + vars_snipesigs.append(var_sig) + logger.debug(f"Loaded variable signature: {var_sig.name}") + except Exception as e: + logger.error(f"Failed to load variable signature from {path}: {e}") + logger.debug(f"Variable signature paths: {vars_paths}") + - # Prepare arguments for process_sample function - dict_process_args = [] - for 
sample_path in valid_samples: - dict_process_args.append({ - "sample_path": sample_path, - "ref_path": ref, - "amplicon_path": amplicon, - "advanced": advanced, - "roi": roi, - "debug": debug, - "ychr": ychr, - "vars_paths": vars_paths - }) - - results = [] + predict_extra_folds = [1, 2, 5, 9] + + + qc_instance = MultiSigReferenceQC( + reference_sig=reference_sig, + amplicon_sig=amplicon_sig, + ychr=ychr_sig if ychr_sig else None, + varsigs=vars_snipesigs if vars_snipesigs else None, + enable_logging=debug + ) + + sample_to_stats = {} + failed_samples = [] + for sample_path in tqdm(valid_samples): + sample_sig = SnipeSig(sourmash_sig=sample_path, sig_type=SigType.SAMPLE, enable_logging=debug) + try: + sample_stats = qc_instance.process_sample(sample_sig=sample_sig, + predict_extra_folds = predict_extra_folds if roi else None, + advanced=advanced) + sample_to_stats[sample_sig.name] = sample_stats + except Exception as e: + failed_samples.append(sample_sig.name) + qc_instance.logger.error(f"Failed to process sample {sample_sig.name}: {e}") + continue - # Define a handler for graceful shutdown - def shutdown(signum, frame): - logger.warning("Shutdown signal received. 
Terminating all worker processes...") - executor.shutdown(wait=False, cancel_futures=True) - sys.exit(1) - - # Register signal handlers - signal.signal(signal.SIGINT, shutdown) - signal.signal(signal.SIGTERM, shutdown) - try: - with concurrent.futures.ProcessPoolExecutor(max_workers=cores) as executor: - futures = { - executor.submit(process_sample, **args): args for args in dict_process_args - } - - for future in tqdm(concurrent.futures.as_completed(futures), total=len(futures), desc="Processing samples"): - sample = futures[future] - try: - result = future.result() - results.append(result) - except Exception as exc: - logger.error(f"Sample {sample['sample_path']} generated an exception: {exc}") - results.append({ - "sample": os.path.splitext(os.path.basename(sample['sample_path']))[0], - "file_path": sample['sample_path'], - "QC_Error": str(exc) - }) - except KeyboardInterrupt: - logger.warning("KeyboardInterrupt received. Shutting down...") - sys.exit(1) - except Exception as e: - logger.error(f"An unexpected error occurred: {e}") - sys.exit(1) - # Separate successful and failed results - succeeded = [res for res in results if "QC_Error" not in res] - failed = [res for res in results if "QC_Error" in res] + succeeded = list(sample_to_stats.keys()) + failed = len(failed_samples) # Handle complete failure if len(succeeded) == 0: logger.error("All samples failed during QC processing. 
Output TSV will not be generated.") sys.exit(1) + + # write total success and failure + logger.info("Successfully processed samples: %d", len(succeeded)) # Prepare the command-line invocation for comments command_invocation = ' '.join(sys.argv) # Create pandas DataFrame for succeeded samples - df = pd.DataFrame(succeeded) + df = pd.DataFrame(sample_to_stats.values()) # Reorder columns to have 'sample' and 'file_path' first, if they exist cols = list(df.columns)
diff --git a/src/snipe/cli/cli_qc_parallel.py b/src/snipe/cli/cli_qc_parallel.py
new file mode 100644
index 0000000..98a33aa
--- /dev/null
+++ b/src/snipe/cli/cli_qc_parallel.py
@@ -0,0 +1,393 @@
+import os
+import sys
+import time
+import logging
+from typing import Optional, Any, List, Dict, Set, Union
+
+import click
+import pandas as pd
+from tqdm import tqdm
+import concurrent.futures
+
+from snipe.api.enums import SigType
+from snipe.api.sketch import SnipeSketch
+from snipe.api.snipe_sig import SnipeSig
+from snipe.api.reference_QC import ReferenceQC
+
+import signal
+
+
+def validate_sig_input(ctx, param, value: tuple) -> tuple:
+    """
+    Click callback that validates every path given through `--sample`.
+
+    Parameters:
+    - ctx: Click context (unused).
+    - param: Click parameter being validated (unused).
+    - value (tuple): Paths collected from the repeatable `--sample` option.
+
+    Returns:
+    - tuple: The validated paths, unchanged. Click takes the callback's return
+      value as the option value, so omitting the return would replace the
+      paths with `None`.
+
+    Raises:
+    - click.BadParameter: If a path does not exist or has an unsupported extension.
+    """
+    supported_extensions = ('.zip', '.sig')
+    for path in value:
+        if not os.path.exists(path):
+            raise click.BadParameter(f"File not found: {path}")
+        # str.endswith accepts a tuple, so one call checks every supported suffix.
+        if not path.lower().endswith(supported_extensions):
+            raise click.BadParameter(f"Unsupported file format: {path}, supported formats: {', '.join(supported_extensions)}")
+    return value
+
+
+def process_sample(sample_path: str, reference_sig: SnipeSig, amplicon_sig: Optional[SnipeSig],
+                   advanced: bool, roi: bool, debug: bool,
+                   ychr_sig: Optional[SnipeSig] = None,
+                   var_sigs: Optional[List[SnipeSig]] = None) -> Dict[str, Any]:
+    """
+    Process a single sample for QC.
+
+    Parameters:
+    - sample_path (str): Path to the sample signature file.
+    - reference_sig (SnipeSig): Loaded reference genome signature.
+    - amplicon_sig (Optional[SnipeSig]): Loaded amplicon signature, if any.
+    - advanced (bool): Flag to include advanced metrics.
+    - roi (bool): Flag to calculate ROI.
+    - debug (bool): Flag to enable debugging.
+    - ychr_sig (Optional[SnipeSig]): Y chromosome signature; when given it is added to the chromosome map under the key 'y'.
+    - var_sigs (Optional[List[SnipeSig]]): Loaded variable signatures, handed to the QC instance in list order.
+
+    Returns:
+    - Dict[str, Any]: QC results for the sample. If processing fails, the dict holds
+      only "sample", "file_path" and "QC_Error" (the error message).
+    """
+    # Configure worker-specific logging
+    logger = logging.getLogger(f'snipe_qc_worker_{os.path.basename(sample_path)}')
+    logger.setLevel(logging.DEBUG if debug else logging.INFO)
+    if not logger.hasHandlers():
+        handler = logging.StreamHandler(sys.stdout)
+        handler.setLevel(logging.DEBUG if debug else logging.INFO)
+        formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
+        handler.setFormatter(formatter)
+        logger.addHandler(handler)
+
+    sample_name = os.path.splitext(os.path.basename(sample_path))[0]
+    try:
+        # Load sample signature
+        sample_sig = SnipeSig(sourmash_sig=sample_path, sig_type=SigType.SAMPLE, enable_logging=debug)
+        logger.debug(f"Loaded sample signature: {sample_sig.name}")
+
+        # Instantiate ReferenceQC
+        qc_instance = ReferenceQC(
+            sample_sig=sample_sig,
+            reference_sig=reference_sig,
+            amplicon_sig=amplicon_sig,
+            enable_logging=debug
+        )
+
+        # Load variable signatures if provided
+        if var_sigs:
+            qc_instance.logger.debug(f"Loading {len(var_sigs)} variable signature(s).")
+            # NOTE(review): signatures that share a name overwrite each other in this dict.
+            vars_dict: Dict[str, SnipeSig] = {sig.name: sig for sig in var_sigs}
+            qc_instance.logger.debug(f"vars_dict: {vars_dict}")
+            # Dict iteration preserves insertion order, i.e. the order of var_sigs.
+            vars_order: List[str] = list(vars_dict)
+
+            qc_instance.logger.debug("Loading variable signature(s) from SnipeSig objects.")
+            for snipe_name in vars_order:
+                qc_instance.logger.debug(f"Loaded variable signature: {snipe_name}")
+
+            # log keys of vars_dict and vars_order
+            qc_instance.logger.debug(f"vars_dict keys: {vars_dict.keys()}")
+            qc_instance.logger.debug(f"vars_order: {vars_order}")
+
+            qc_instance.nonref_consume_from_vars(vars=vars_dict, vars_order=vars_order)
+
+        # Calculate chromosome metrics
+        # Work on a copy before adding the optional 'y' entry.
+        chr_to_sig = reference_sig.chr_to_sig.copy()
+        if ychr_sig:
+            chr_to_sig['y'] = ychr_sig
+
+        qc_instance.calculate_chromosome_metrics(chr_to_sig)
+
+        # Get aggregated stats
+        aggregated_stats = qc_instance.get_aggregated_stats(include_advanced=advanced)
+
+        # Initialize result dict
+        result = {
+            "sample": sample_name,
+            "file_path": os.path.abspath(sample_path),
+        }
+        # Add aggregated stats
+        result.update(aggregated_stats)
+
+        # Calculate ROI if requested
+        if roi:
+            logger.debug(f"Calculating ROI for sample: {sample_name}")
+            for fold in (1, 2, 5, 9):
+                try:
+                    predicted_coverage = qc_instance.predict_coverage(extra_fold=fold)
+                    result[f"Predicted_Coverage_Fold_{fold}x"] = predicted_coverage
+                    logger.debug(f"Fold {fold}x: Predicted Coverage = {predicted_coverage}")
+                except RuntimeError as e:
+                    logger.error(f"ROI calculation failed for sample {sample_name} at fold {fold}x: {e}")
+                    result[f"Predicted_Coverage_Fold_{fold}x"] = None
+
+        return result
+
+    except Exception as e:
+        logger.error(f"QC failed for sample {sample_path}: {e}")
+        return {
+            "sample": sample_name,
+            "file_path": os.path.abspath(sample_path),
+            "QC_Error": str(e)
+        }
+
+
+def validate_tsv_file(ctx, param, value: str) -> str:
+    """
+    Click callback that requires the output path to end with `.tsv` (case-insensitive).
+
+    Returns:
+    - str: The output path, unchanged.
+
+    Raises:
+    - click.BadParameter: If the path does not end with `.tsv`.
+    """
+    if not value.lower().endswith('.tsv'):
+        raise click.BadParameter('Output file must have a .tsv extension.')
+    return value
+
+
+@click.command()
+@click.option('--ref', type=click.Path(exists=True), required=True, help='Reference genome signature file (required).')
+@click.option('--sample', type=click.Path(exists=True), callback=validate_sig_input, multiple=True, help='Sample signature file. 
Can be provided multiple times.')
+@click.option('--samples-from-file', type=click.Path(exists=True), help='File containing sample paths (one per line).')
+@click.option('--amplicon', type=click.Path(exists=True), help='Amplicon signature file (optional).')
+@click.option('--roi', is_flag=True, default=False, help='Calculate ROI for 1,2,5,9 folds.')
+@click.option('--cores', '-c', default=4, type=int, show_default=True, help='Number of CPU cores to use for parallel processing.')
+@click.option('--advanced', is_flag=True, default=False, help='Include advanced QC metrics.')
+@click.option('--ychr', type=click.Path(exists=True), help='Y chromosome signature file (overrides the reference ychr).')
+@click.option('--debug', is_flag=True, default=False, help='Enable debugging and detailed logging.')
+@click.option('-o', '--output', required=True, callback=validate_tsv_file, help='Output TSV file for QC results.')
+@click.option('--var', 'vars', multiple=True, type=click.Path(exists=True), help='Variable signature file path. Can be used multiple times.')
+def parallel_qc(ref: str, sample: List[str], samples_from_file: Optional[str],
+                amplicon: Optional[str], roi: bool, cores: int, advanced: bool,
+                ychr: Optional[str], debug: bool, output: str, vars: List[str]):
+    """
+    Parallelized version of the `qc` command (not optimized for memory).
+    """
+    # NOTE: the docstring above is shown by Click as the command's help text.
+
+    start_time = time.time()
+
+    # Configure logging
+    logger = logging.getLogger('snipe_qc_parallel')
+    logger.setLevel(logging.DEBUG if debug else logging.INFO)
+    handler = logging.StreamHandler(sys.stdout)
+    handler.setLevel(logging.DEBUG if debug else logging.INFO)
+    formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
+    handler.setFormatter(formatter)
+    if not logger.hasHandlers():
+        logger.addHandler(handler)
+
+    logger.info("Starting QC process.")
+
+    # warning of high memory usage if used large var files
+    if vars:
+        logger.warning("Using large variable signature files may result in high memory usage.")
+
+
+    # Collect sample paths from --sample and --samples-from-file
+    samples_set: Set[str] = set()
+    if sample:
+        samples_set.update(sample)
+
+    if samples_from_file:
+        logger.debug(f"Reading samples from file: {samples_from_file}")
+        try:
+            with open(samples_from_file, 'r', encoding='utf-8') as f:
+                file_samples = {line.strip() for line in f if line.strip()}
+            samples_set.update(file_samples)
+            logger.debug(f"Collected {len(file_samples)} samples from file.")
+        except Exception as e:
+            logger.error(f"Failed to read samples from file {samples_from_file}: {e}")
+            sys.exit(1)
+
+    # Deduplicate and validate sample paths
+    valid_samples = []
+    for sample_path in samples_set:
+        if os.path.exists(sample_path):
+            valid_samples.append(os.path.abspath(sample_path))
+        else:
+            logger.warning(f"Sample file does not exist and will be skipped: {sample_path}")
+
+    if not valid_samples:
+        logger.error("No valid samples provided for QC.")
+        sys.exit(1)
+
+    logger.info(f"Total valid samples to process: {len(valid_samples)}")
+
+    # Load reference signature
+    logger.info(f"Loading reference signature from: {ref}")
+    try:
+        reference_sig = SnipeSig(sourmash_sig=ref, sig_type=SigType.GENOME, enable_logging=debug)
+        logger.debug(f"Loaded reference signature: {reference_sig.name}")
+    except Exception as e:
+        logger.error(f"Failed to load reference signature from {ref}: {e}")
+        sys.exit(1)
+
+    # Load amplicon signature if provided
+    amplicon_sig = None
+    if amplicon:
+        logger.info(f"Loading amplicon signature from: {amplicon}")
+        try:
+            amplicon_sig = SnipeSig(sourmash_sig=amplicon, sig_type=SigType.AMPLICON, enable_logging=debug)
+            logger.debug(f"Loaded amplicon signature: {amplicon_sig.name}")
+        except Exception as e:
+            logger.error(f"Failed to load amplicon signature from {amplicon}: {e}")
+            sys.exit(1)
+
+    # Load Y chromosome signature if provided
+    ychr_sig = None
+    if ychr:
+        logger.info(f"Loading Y chromosome signature from: {ychr}")
+        try:
+            ychr_sig = SnipeSig(sourmash_sig=ychr, sig_type=SigType.GENOME, enable_logging=debug)
+            logger.debug(f"Loaded Y chromosome signature: {ychr_sig.name}")
+        except Exception as e:
+            logger.error(f"Failed to load Y chromosome signature from {ychr}: {e}")
+            sys.exit(1)
+
+    # Prepare variable signatures if provided
+    vars_paths = []
+    vars_snipesigs = []
+    if vars:
+        logger.info(f"Loading {len(vars)} variable signature(s).")
+        for path in vars:
+            if not os.path.exists(path):
+                logger.error(f"Variable signature file does not exist: {path}")
+                sys.exit(1)
+            vars_paths.append(os.path.abspath(path))
+            try:
+                var_sig = SnipeSig(sourmash_sig=path, sig_type=SigType.AMPLICON, enable_logging=debug)
+                vars_snipesigs.append(var_sig)
+                logger.debug(f"Loaded variable signature: {var_sig.name}")
+            except Exception as e:
+                logger.error(f"Failed to load variable signature from {path}: {e}")
+
+        logger.debug(f"Variable signature paths: {vars_paths}")
+
+    # Prepare arguments for process_sample function
+    dict_process_args = []
+    for sample_path in valid_samples:
+        dict_process_args.append({
+            "sample_path": sample_path,
+            "reference_sig": reference_sig,
+            "amplicon_sig": amplicon_sig,
+            "advanced": advanced,
+            "roi": roi,
+            "debug": debug,
+            "ychr_sig": ychr_sig,
+            "var_sigs": vars_snipesigs #vars_paths
+        })
+
+    results = []
+
+    # Pool handle read by the shutdown handler below; stays None until the pool exists
+    executor = None
+    def shutdown(signum, frame):
+        logger.warning("Shutdown signal received. Terminating all worker processes...")
+        nonlocal executor  # a local of parallel_qc; `global` would raise NameError here
+        if executor:
+            executor.shutdown(wait=False, cancel_futures=True)
+        os._exit(1)  # Forcefully terminate the program
+
+    # Register signal handlers
+    signal.signal(signal.SIGINT, shutdown)
+    signal.signal(signal.SIGTERM, shutdown)
+
+    try:
+        with concurrent.futures.ProcessPoolExecutor(max_workers=cores) as exec_instance:
+            executor = exec_instance  # Assign the executor instance
+            futures = {
+                executor.submit(process_sample, **args): args for args in dict_process_args
+            }
+
+            for future in tqdm(concurrent.futures.as_completed(futures), total=len(futures), desc="Processing samples"):
+                sample = futures[future]
+                try:
+                    result = future.result()
+                    results.append(result)
+                except Exception as exc:
+                    logger.error(f"Sample {sample['sample_path']} generated an exception: {exc}")
+                    results.append({
+                        "sample": os.path.splitext(os.path.basename(sample['sample_path']))[0],
+                        "file_path": sample['sample_path'],
+                        "QC_Error": str(exc)
+                    })
+    except KeyboardInterrupt:
+        logger.warning("KeyboardInterrupt received. Shutting down...")
+        os._exit(1)  # Forcefully terminate the program
+    except Exception as e:
+        logger.error(f"An unexpected error occurred: {e}")
+        os._exit(1)  # Forcefully terminate the program
+    finally:
+        executor = None  # Reset executor
+
+    # Separate successful and failed results
+    succeeded = [res for res in results if "QC_Error" not in res]
+    failed = [res for res in results if "QC_Error" in res]
+
+    # Handle complete failure
+    if not succeeded:
+        logger.error("All samples failed during QC processing. Output TSV will not be generated.")
+        sys.exit(1)
+
+    # write total success and failure
+    logger.info("Successfully processed samples: %d", len(succeeded))
+
+    # Prepare the command-line invocation for comments
+    command_invocation = ' '.join(sys.argv)
+
+    # Create pandas DataFrame for succeeded samples
+    df = pd.DataFrame(succeeded)
+
+    # Reorder columns to have 'sample' and 'file_path' first, if they exist
+    cols = list(df.columns)
+    reordered_cols = []
+    for col in ['sample', 'file_path']:
+        if col in cols:
+            reordered_cols.append(col)
+            cols.remove(col)
+    reordered_cols += cols
+    df = df[reordered_cols]
+
+    # Export to TSV with comments
+    try:
+        with open(output, 'w', encoding='utf-8') as f:
+            # Write comment with command invocation
+            f.write(f"# Command: {command_invocation}\n")
+            # Write the DataFrame to the file
+            df.to_csv(f, sep='\t', index=False)
+        logger.info(f"QC results successfully exported to {output}")
+    except Exception as e:
+        logger.error(f"Failed to export QC results to {output}: {e}")
+        sys.exit(1)
+
+    # Report failed samples if any
+    if failed:
+        failed_samples = [res['sample'] for res in failed]
+        logger.warning(f"The following {len(failed_samples)} sample(s) failed during QC processing:")
+        for sample in failed_samples:
+            logger.warning(f"- {sample}")
+
+    end_time = time.time()
+    elapsed_time = end_time - start_time
+    logger.info(f"QC process completed in {elapsed_time:.2f} seconds.")
\ No newline at end of file
diff --git a/src/snipe/cli/main.py b/src/snipe/cli/main.py index acd14ef..d6ec11f 100644 --- a/src/snipe/cli/main.py +++ b/src/snipe/cli/main.py @@ -17,6 +17,7 @@ from snipe.cli.cli_qc import qc as cli_qc from snipe.cli.cli_sketch import sketch as cli_sketch from snipe.cli.cli_ops import ops as cli_ops +from snipe.cli.cli_qc_parallel import parallel_qc # pylint: disable=logging-fstring-interpolation @@ -31,11 +32,14 @@ def cli(): Commands: - `sketch`: Perform sketching operations on genomic data. 
- `qc`: Execute quality control (QC) on multiple samples against a reference genome. + - `parallel-qc`: Parallelized version of the `qc` command (not optimized for memory). + - `ops`: Perform various operations on sketches. """ pass cli.add_command(cli_qc) +cli.add_command(parallel_qc) cli.add_command(cli_sketch) cli.add_command(cli_ops)