From 89cb1d6a3d5413a8d8dc126f3a5a1edb32b141b7 Mon Sep 17 00:00:00 2001 From: Mohamed Abuelanin Date: Sun, 13 Oct 2024 11:22:50 -0700 Subject: [PATCH 1/7] remove dead usecase --- tests/api/test_reference_qc.py | 13 ------------- 1 file changed, 13 deletions(-) diff --git a/tests/api/test_reference_qc.py b/tests/api/test_reference_qc.py index d98327e..e48c294 100644 --- a/tests/api/test_reference_qc.py +++ b/tests/api/test_reference_qc.py @@ -868,19 +868,6 @@ def test_calculate_coverage_vs_depth_with_empty_splits(self): self.assertEqual(data_point["cumulative_coverage_index"], 0.0) self.assertEqual(data_point["cumulative_total_abundance"], 0) - - def test_predict_coverage_with_very_large_extra_fold(self): - """ - Test that predict_coverage caps the predicted coverage at 1.0 even with a very large extra_fold. - """ - qc = ReferenceQC( - sample_sig=self.sample_sig, - reference_sig=self.reference_sig, - enable_logging=False - ) - predicted_coverage = qc.predict_coverage(extra_fold=1000.0, n=40) - self.assertGreaterEqual(predicted_coverage, 0.9) - self.assertLessEqual(predicted_coverage, 1.0) def test_calculate_sex_chrs_metrics(self): """ From c0a56180fb23f45724cfb39252357a1d7adcfd34 Mon Sep 17 00:00:00 2001 From: Mohamed Abuelanin Date: Sun, 13 Oct 2024 11:26:21 -0700 Subject: [PATCH 2/7] FracMinHash not MinHash as description --- docs/mkdocs.yml | 1 + src/snipe/api/sketch.py | 10 +++++----- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/docs/mkdocs.yml b/docs/mkdocs.yml index 391b741..e05a6a3 100644 --- a/docs/mkdocs.yml +++ b/docs/mkdocs.yml @@ -82,6 +82,7 @@ nav: - API: - SnipeSig: SnipeSig.md - ReferenceQC: ReferenceQC.md + - Sketch: Sketch.md - CLI: cli.md diff --git a/src/snipe/api/sketch.py b/src/snipe/api/sketch.py index 2a8b677..941da5f 100644 --- a/src/snipe/api/sketch.py +++ b/src/snipe/api/sketch.py @@ -18,7 +18,7 @@ class SnipeSketch: """ - SnipeSketch is responsible for creating MinHash sketches from genomic data. 
+ SnipeSketch is responsible for creating FracMinHash sketches from genomic data. It supports parallel processing, progress monitoring, and different sketching modes including sample, genome, and amplicon sketching. """ @@ -69,7 +69,7 @@ def process_sequences( scaled: int = 10_000, ) -> sourmash.MinHash: """ - Process a subset of sequences to create a MinHash sketch. + Process a subset of sequences to create a FracMinHash sketch. Each process creates its own MinHash instance and processes sequences assigned based on the thread ID. Progress is reported via a shared queue. @@ -84,7 +84,7 @@ def process_sequences( scaled (int, optional): Scaling factor for MinHash. Defaults to 10_000. Returns: - sourmash.MinHash: The resulting MinHash sketch. + sourmash.MinHash: The resulting FracMinHash sketch. """ self._register_signal_handler() try: @@ -180,7 +180,7 @@ def _sketch_sample( **kwargs: Any, ) -> sourmash.SourmashSignature: """ - Create a MinHash sketch for a sample using parallel processing. + Create a FracMinHash sketch for a sample using parallel processing. Args: sample_name (str): Name of the sample. @@ -465,7 +465,7 @@ def amplicon_sketching( amplicon_name: str = "amplicon", ) -> sourmash.SourmashSignature: """ - Create a MinHash sketch for an amplicon. + Create a FracMinHash sketch for an amplicon. Args: fasta_file (str): Path to the FASTA file. 
From dbbab13fe126c95f499637cdd8ff0bf86193e139 Mon Sep 17 00:00:00 2001 From: Mohamed Abuelanin Date: Mon, 14 Oct 2024 10:21:26 -0700 Subject: [PATCH 3/7] checkpoint in QC --- src/snipe/cli/main.py | 280 +++++++++++++++++++++++++++++++++++------- 1 file changed, 233 insertions(+), 47 deletions(-) diff --git a/src/snipe/cli/main.py b/src/snipe/cli/main.py index 907c755..132be45 100644 --- a/src/snipe/cli/main.py +++ b/src/snipe/cli/main.py @@ -2,54 +2,35 @@ import sys import time import logging -from typing import Optional, Any +from typing import Optional, Any, List, Dict, Set import click +import pandas as pd +from tqdm import tqdm +import concurrent.futures from snipe.api.enums import SigType from snipe.api.sketch import SnipeSketch +from snipe.api.snipe_sig import SnipeSig +from snipe.api.reference_QC import ReferenceQC +# pylint: disable=logging-fstring-interpolation def validate_zip_file(ctx, param, value: str) -> str: """ Validate that the output file has a .zip extension. - - Args: - ctx: Click context. - param: Click parameter. - value (str): The value of the parameter. - - Raises: - click.BadParameter: If the file does not have a .zip extension. - - Returns: - str: The validated file path. """ if not value.lower().endswith('.zip'): raise click.BadParameter('Output file must have a .zip extension.') return value -def ensure_mutually_exclusive(ctx, param, value: Any) -> Any: +def validate_tsv_file(ctx, param, value: str) -> str: """ - Ensure that only one of --sample, --ref, or --amplicon is provided. - - Args: - ctx: Click context. - param: Click parameter. - value (Any): The value of the parameter. - - Raises: - click.UsageError: If more than one or none of the mutually exclusive options are provided. - - Returns: - Any: The validated value. + Validate that the output file has a .tsv extension. 
""" - sample, ref, amplicon = ctx.params.get('sample'), ctx.params.get('ref'), ctx.params.get('amplicon') - if sum([bool(sample), bool(ref), bool(amplicon)]) > 1: - raise click.UsageError('Only one of --sample, --ref, or --amplicon can be used at a time.') - if not any([bool(sample), bool(ref), bool(amplicon)]): - raise click.UsageError('You must specify one of --sample, --ref, or --amplicon.') + if not value.lower().endswith('.tsv'): + raise click.BadParameter('Output file must have a .tsv extension.') return value @@ -58,7 +39,7 @@ def cli(): """ Snipe CLI Tool - Use this tool to perform various sketching operations on genomic data. + Use this tool to perform various sketching and quality control operations on genomic data. """ pass @@ -86,7 +67,11 @@ def sketch(ctx, sample: Optional[str], ref: Optional[str], amplicon: Optional[st You must specify exactly one of --sample, --ref, or --amplicon. """ # Ensure mutual exclusivity - ensure_mutually_exclusive(ctx, None, None) + samples = [sample, ref, amplicon] + provided = [s for s in samples if s] + if len(provided) != 1: + click.echo('Error: Exactly one of --sample, --ref, or --amplicon must be provided.') + sys.exit(1) # Handle existing output file if os.path.exists(output_file): @@ -179,8 +164,8 @@ def sketch(ctx, sample: Optional[str], ref: Optional[str], amplicon: Optional[st chr_to_sig[snipe_ychr_name] = y_chr_sig # Log the detected chromosomes - autosomal = [name for name in chr_to_sig.keys() if "autosome" in name] - sex = [name for name in chr_to_sig.keys() if "sex" in name] + autosomal = [name for name in chr_to_sig.keys() if "autosome" in name.lower()] + sex = [name for name in chr_to_sig.keys() if "sex" in name.lower()] click.echo("Autodetected chromosomes:") for i, chr_name in enumerate(chr_to_sig.keys(), 1): @@ -223,26 +208,227 @@ def sketch(ctx, sample: Optional[str], ref: Optional[str], amplicon: Optional[st click.echo(f"Sketching completed in {elapsed_time:.2f} seconds.") -# Add sketch command to cli 
-cli.add_command(sketch) +# Define the top-level process_sample function +def process_sample(sample_path: str, ref_path: str, amplicon_path: Optional[str], + advanced: bool, roi: bool, debug: bool) -> Dict[str, Any]: + """ + Process a single sample for QC. + Parameters: + - sample_path (str): Path to the sample signature file. + - ref_path (str): Path to the reference signature file. + - amplicon_path (Optional[str]): Path to the amplicon signature file. + - advanced (bool): Flag to include advanced metrics. + - roi (bool): Flag to calculate ROI. + - debug (bool): Flag to enable debugging. -# Example placeholder for future 'qc' command -@cli.group() -def qc(): - """ - Perform quality control operations. + Returns: + - Dict[str, Any]: QC results for the sample. """ - pass + # Configure worker-specific logging + logger = logging.getLogger(f'snipe_qc_worker_{os.path.basename(sample_path)}') + logger.setLevel(logging.DEBUG if debug else logging.INFO) + if not logger.hasHandlers(): + handler = logging.StreamHandler(sys.stdout) + handler.setLevel(logging.DEBUG if debug else logging.INFO) + formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') + handler.setFormatter(formatter) + logger.addHandler(handler) + + sample_name = os.path.splitext(os.path.basename(sample_path))[0] + try: + # Load sample signature + sample_sig = SnipeSig(sourmash_sig=sample_path, sig_type=SigType.SAMPLE, enable_logging=debug) + logger.debug(f"Loaded sample signature: {sample_sig.name}") + + # Load reference signature + reference_sig = SnipeSig(sourmash_sig=ref_path, sig_type=SigType.GENOME, enable_logging=debug) + logger.debug(f"Loaded reference signature: {reference_sig.name}") + + # Load amplicon signature if provided + amplicon_sig = None + if amplicon_path: + amplicon_sig = SnipeSig(sourmash_sig=amplicon_path, sig_type=SigType.AMPLICON, enable_logging=debug) + logger.debug(f"Loaded amplicon signature: {amplicon_sig.name}") + + # Instantiate ReferenceQC + 
qc_instance = ReferenceQC( + sample_sig=sample_sig, + reference_sig=reference_sig, + amplicon_sig=amplicon_sig, + enable_logging=debug + ) + + # calculate chromosome metrics + qc_instance.calculate_chromosome_metrics() + + # Get aggregated stats + aggregated_stats = qc_instance.get_aggregated_stats(include_advanced=advanced) + + # Initialize result dict + result = { + "sample": sample_name, + "file_path": os.path.abspath(sample_path), + } + # Add aggregated stats + result.update(aggregated_stats) + + # Calculate ROI if requested + if roi: + logger.debug(f"Calculating ROI for sample: {sample_name}") + for fold in [1, 2, 5, 9]: + try: + predicted_coverage = qc_instance.predict_coverage(extra_fold=fold) + result[f"Predicted_Coverage_Fold_{fold}x"] = predicted_coverage + logger.debug(f"Fold {fold}x: Predicted Coverage = {predicted_coverage}") + except RuntimeError as e: + logger.error(f"ROI calculation failed for sample {sample_name} at fold {fold}x: {e}") + result[f"Predicted_Coverage_Fold_{fold}x"] = None + + return result + + except Exception as e: + logger.error(f"QC failed for sample {sample_path}: {e}") + return { + "sample": sample_name, + "file_path": os.path.abspath(sample_path), + "QC_Error": str(e) + } -@qc.command() -def run_qc(): +@cli.command() +@click.option('--ref', type=click.Path(exists=True), required=True, help='Reference genome signature file (required).') +@click.option('--sample', type=click.Path(exists=True), multiple=True, help='Sample signature file. 
Can be provided multiple times.') +@click.option('--samples-from-file', type=click.Path(exists=True), help='File containing sample paths (one per line).') +@click.option('--amplicon', type=click.Path(exists=True), help='Amplicon signature file (optional).') +@click.option('--roi', is_flag=True, default=False, help='Calculate ROI for 1,2,5,9 folds.') +@click.option('--cores', '-c', default=4, type=int, show_default=True, help='Number of CPU cores to use for parallel processing.') +@click.option('--advanced', is_flag=True, default=False, help='Include advanced QC metrics.') +@click.option('--debug', is_flag=True, default=False, help='Enable debugging and detailed logging.') +@click.option('-o', '--output', required=True, callback=validate_tsv_file, help='Output TSV file for QC results.') +def qc(ref: str, sample: List[str], samples_from_file: Optional[str], + amplicon: Optional[str], roi: bool, cores: int, advanced: bool, + debug: bool, output: str): """ - Run quality control checks. + Perform quality control (QC) on multiple samples against a reference genome. + + This command calculates various QC metrics for each provided sample, optionally including advanced metrics + and ROI predictions. Results are aggregated and exported to a TSV file. 
""" - click.echo('Quality control functionality is not yet implemented.') - # Future implementation goes here + start_time = time.time() + + # Configure logging + logger = logging.getLogger('snipe_qc') + logger.setLevel(logging.DEBUG if debug else logging.INFO) + handler = logging.StreamHandler(sys.stdout) + handler.setLevel(logging.DEBUG if debug else logging.INFO) + formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s') + handler.setFormatter(formatter) + if not logger.hasHandlers(): + logger.addHandler(handler) + + logger.info("Starting QC process.") + + # Collect sample paths from --sample and --samples-from-file + samples_set: Set[str] = set(sample) # Start with samples provided via --sample + + if samples_from_file: + logger.debug(f"Reading samples from file: {samples_from_file}") + try: + with open(samples_from_file, encoding='utf-8') as f: + file_samples = {line.strip() for line in f if line.strip()} + samples_set.update(file_samples) + logger.debug(f"Collected {len(file_samples)} samples from file.") + except Exception as e: + logger.error(f"Failed to read samples from file {samples_from_file}: {e}") + sys.exit(1) + + # Deduplicate and validate sample paths + valid_samples = [] + for sample_path in samples_set: + if os.path.exists(sample_path): + valid_samples.append(os.path.abspath(sample_path)) + else: + logger.warning(f"Sample file does not exist and will be skipped: {sample_path}") + + if not valid_samples: + logger.error("No valid samples provided for QC.") + sys.exit(1) + + logger.info(f"Total valid samples to process: {len(valid_samples)}") + + # Load reference signature + logger.info(f"Loading reference signature from: {ref}") + try: + reference_sig = SnipeSig(sourmash_sig=ref, sig_type=SigType.GENOME, enable_logging=debug) + logger.debug(f"Loaded reference signature: {reference_sig.name}") + except Exception as e: + logger.error(f"Failed to load reference signature from {ref}: {e}") + sys.exit(1) + + # Load amplicon signature 
if provided + amplicon_sig = None + if amplicon: + logger.info(f"Loading amplicon signature from: {amplicon}") + try: + amplicon_sig = SnipeSig(sourmash_sig=amplicon, sig_type=SigType.AMPLICON, enable_logging=debug) + logger.debug(f"Loaded amplicon signature: {amplicon_sig.name}") + except Exception as e: + logger.error(f"Failed to load amplicon signature from {amplicon}: {e}") + sys.exit(1) + + # Prepare arguments for parallel processing + process_args = [ + (sample_path, ref, amplicon, advanced, roi, debug) + for sample_path in valid_samples + ] + + # Process samples in parallel with progress bar + results = [] + with concurrent.futures.ProcessPoolExecutor(max_workers=cores) as executor: + # Submit all tasks + futures = { + executor.submit(process_sample, *args): args[0] for args in process_args + } + # Iterate over completed futures with a progress bar + for future in tqdm(concurrent.futures.as_completed(futures), total=len(futures), desc="Processing samples"): + sample = futures[future] + try: + result = future.result() + results.append(result) + except Exception as exc: + logger.error(f"Sample {sample} generated an exception: {exc}") + results.append({ + "sample": os.path.splitext(os.path.basename(sample))[0], + "file_path": sample, + "QC_Error": str(exc) + }) + + # Create pandas DataFrame + logger.info("Aggregating results into DataFrame.") + df = pd.DataFrame(results) + + # Reorder columns to have 'sample' and 'file_path' first, if they exist + cols = list(df.columns) + reordered_cols = [] + for col in ['sample', 'file_path']: + if col in cols: + reordered_cols.append(col) + cols.remove(col) + reordered_cols += cols + df = df[reordered_cols] + + # Export to TSV + try: + df.to_csv(output, sep='\t', index=False) + logger.info(f"QC results successfully exported to {output}") + except Exception as e: + logger.error(f"Failed to export QC results to {output}: {e}") + sys.exit(1) + + end_time = time.time() + elapsed_time = end_time - start_time + logger.info(f"QC 
process completed in {elapsed_time:.2f} seconds.") if __name__ == '__main__': From 4f216f131393e12820781a8300bd2972e07ce181 Mon Sep 17 00:00:00 2001 From: Mohamed Abuelanin Date: Mon, 14 Oct 2024 11:12:42 -0700 Subject: [PATCH 4/7] separating SnipeSig --- src/snipe/api/snipe_sig.py | 1459 ++++++++++++++++++++++++++++++++++++ 1 file changed, 1459 insertions(+) create mode 100644 src/snipe/api/snipe_sig.py diff --git a/src/snipe/api/snipe_sig.py b/src/snipe/api/snipe_sig.py new file mode 100644 index 0000000..0b5ee64 --- /dev/null +++ b/src/snipe/api/snipe_sig.py @@ -0,0 +1,1459 @@ +import heapq +import logging +from snipe.api.enums import SigType +from typing import Any, Dict, Iterator, List, Optional, Union +import numpy as np +import sourmash + + +# Configure the root logger to CRITICAL to suppress unwanted logs by default +logging.basicConfig(level=logging.CRITICAL) + + +class SnipeSig: + """ + A class to handle Sourmash signatures with additional functionalities + such as customized set operations and abundance management. + """ + + def _try_load_from_json(self, sourmash_sig: str) -> Union[List[sourmash.signature.SourmashSignature], None]: + r""" + Attempt to load sourmash signature from JSON string. + + Parameters: + sourmash_sig (str): JSON string representing a sourmash signature. + + Returns: + sourmash.signature.SourmashSignature or None if loading fails. + """ + try: + self.logger.debug("Trying to load sourmash signature from JSON.") + list_of_sigs = list(sourmash.load_signatures_from_json(sourmash_sig)) + return {sig.name: sig for sig in list_of_sigs} + except Exception as e: + self.logger.debug("Loading from JSON failed. Proceeding to file loading.", exc_info=e) + return None # Return None to indicate failure + + def _try_load_from_file(self, sourmash_sig_path: str) -> Union[List[sourmash.signature.SourmashSignature], None]: + r""" + Attempt to load sourmash signature(s) from a file. 
+ + Parameters: + sourmash_sig_path (str): File path to a sourmash signature. + + Returns: + sourmash.signature.SourmashSignature, list of sourmash.signature.SourmashSignature, or None if loading fails. + """ + self.logger.debug("Trying to load sourmash signature from file.") + try: + signatures = list(sourmash.load_file_as_signatures(sourmash_sig_path)) + self.logger.debug("Loaded %d sourmash signature(s) from file.", len(signatures)) + sigs_dict = {_sig.name: _sig for _sig in signatures} + self.logger.debug("Loaded sourmash signatures into sigs_dict: %s", sigs_dict) + return sigs_dict + except Exception as e: + self.logger.exception("Failed to load the sourmash signature from the file.", exc_info=e) + raise ValueError("An unexpected error occurred while loading the sourmash signature.") from e + + + def __init__(self, *, sourmash_sig: Union[str, sourmash.signature.SourmashSignature, sourmash.signature.FrozenSourmashSignature], + ksize: int = 51, scale: int = 10000, sig_type=SigType.SAMPLE, enable_logging: bool = False, **kwargs): + r""" + Initialize the SnipeSig with a sourmash signature object or a path to a signature. + + Parameters: + sourmash_sig (str or sourmash.signature.SourmashSignature): A path to a signature file or a signature object. + ksize (int): K-mer size. + scale (int): Scale value. + sig_type (SigType): Type of the signature. + enable_logging (bool): Flag to enable detailed logging. + **kwargs: Additional keyword arguments. 
+ """ + # Initialize logging based on the flag + self.logger = logging.getLogger(self.__class__.__name__) + + # Configure the logger + if enable_logging: + self.logger.setLevel(logging.DEBUG) + if not self.logger.hasHandlers(): + # Create console handler + ch = logging.StreamHandler() + ch.setLevel(logging.DEBUG) + # Create formatter + formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') + # Add formatter to handler + ch.setFormatter(formatter) + # Add handler to logger + self.logger.addHandler(ch) + self.logger.debug("Logging is enabled for SnipeSig.") + else: + self.logger.setLevel(logging.CRITICAL) + + # Initialize internal variables + self.logger.debug("Initializing SnipeSig with sourmash_sig: %s", sourmash_sig) + + self._scale = scale + self._ksize = ksize + self._md5sum = None + self._hashes = np.array([], dtype=np.uint64) + self._abundances = np.array([], dtype=np.uint32) + self._type = sig_type + self._name = None + self._filename = None + self._track_abundance = False + + sourmash_sigs: Dict[str, sourmash.signature.SourmashSignature] = {} + _sourmash_sig: Union[sourmash.signature.SourmashSignature, sourmash.signature.FrozenSourmashSignature] = None + + + self.logger.debug("Proceeding with a sigtype of %s", sig_type) + + + + + + + if not isinstance(sourmash_sig, (str, sourmash.signature.SourmashSignature, sourmash.signature.FrozenSourmashSignature)): + # if the str is not a file path + self.logger.error("Invalid type for sourmash_sig: %s", type(sourmash_sig).__name__) + raise TypeError(f"sourmash_sig must be a file path, sourmash.signature.SourmashSignature, or Frozensourmash_signature, got {type(sourmash_sig).__name__}") + + # Case 1: If sourmash_sig is already a valid sourmash signature object + if isinstance(sourmash_sig, (sourmash.signature.FrozenSourmashSignature, sourmash.signature.SourmashSignature)): + self.logger.debug("Loaded sourmash signature directly from object.") + sourmash_sigs = {sourmash_sig.name: 
sourmash_sig} + + # Case 2: If sourmash_sig is a string, try to load as JSON or a file + elif isinstance(sourmash_sig, str): + self.logger.debug("Attempting to load sourmash signature from string input.") + + # First, try loading from JSON + sourmash_sigs = self._try_load_from_json(sourmash_sig) + self.logger.debug("Loaded sourmash signature from JSON: %s", sourmash_sigs) + + # If JSON loading fails, try loading from file + if not sourmash_sigs: + sourmash_sigs = self._try_load_from_file(sourmash_sig) + + # If both attempts fail, raise an error + if not sourmash_sigs: + self.logger.error("Failed to load sourmash signature from the provided string.") + raise ValueError("An unexpected error occurred while loading the sourmash signature.") + + if sig_type == SigType.SAMPLE or sig_type == SigType.AMPLICON: + if len(sourmash_sigs) > 1: + self.logger.debug("Multiple signatures found in the input. Expected a single sample signature.") + # not supported at this time + raise ValueError("Loading multiple sample signatures is not supported at this time.") + elif len(sourmash_sigs) == 1: + self.logger.debug("Found a single signature in the sample sig input; Will use this signature.") + _sourmash_sig = list(sourmash_sigs.values())[0] + else: + self.logger.debug("No signature found in the input. Expected a single sample signature.") + raise ValueError("No signature found in the input. Expected a single sample signature.") + + elif sig_type == SigType.GENOME: + if len(sourmash_sigs) > 1: + for signame, sig in sourmash_sigs.items(): + if signame.endswith("-snipegenome"): + sig = sig.to_mutable() + sig.name = sig.name.replace("-snipegenome", "") + self.logger.debug("Found a genome signature with a snipe modified name. 
Restoring original name `%s`.", sig.name) + _sourmash_sig = sig + break + else: + self.logger.debug("Found multiple signature per the genome file, but none with a snipe modified name.") + raise ValueError("Found multiple signature per the genome file, but none with a snipe modified name.") + elif len(sourmash_sigs) == 1: + self.logger.debug("Found a single signature in the genome sig input; Will use this signature.") + _sourmash_sig = list(sourmash_sigs.values())[0] + else: + self.logger.debug("Unknown sigtype: %s", sig_type) + raise ValueError(f"Unknown sigtype: {sig_type}") + + self.logger.debug("Length of currently loaded signature: %d, with name: %s", len(_sourmash_sig), _sourmash_sig.name) + + # Extract properties from the loaded signature + self._ksize = _sourmash_sig.minhash.ksize + self._scale = _sourmash_sig.minhash.scaled + self._md5sum = _sourmash_sig.md5sum() + self._name = _sourmash_sig.name + self._filename = _sourmash_sig.filename + self._track_abundance = _sourmash_sig.minhash.track_abundance + + # If the signature does not track abundance, assume abundance of 1 for all hashes + if not self._track_abundance: + self.logger.debug("Signature does not track abundance. 
Setting all abundances to 1.") + self._abundances = np.ones(len(_sourmash_sig.minhash.hashes), dtype=np.uint32) + # self._track_abundance = True + else: + self._abundances = np.array(list(_sourmash_sig.minhash.hashes.values()), dtype=np.uint32) + + self._hashes = np.array(list(_sourmash_sig.minhash.hashes.keys()), dtype=np.uint64) + + # Sort the hashes and rearrange abundances accordingly + sorted_indices = np.argsort(self._hashes) + self._hashes = self._hashes[sorted_indices] + self._abundances = self._abundances[sorted_indices] + + self.logger.debug( + "Loaded sourmash signature from file: %s, name: %s, md5sum: %s, ksize: %d, scale: %d, " + "track_abundance: %s, type: %s, length: %d", + self._filename, self._name, self._md5sum, self._ksize, self._scale, + self._track_abundance, self._type, len(self._hashes) + ) + self.logger.debug("Hashes sorted during initialization.") + self.logger.debug("Sourmash signature loading completed successfully.") + + # Setters and getters + @property + def hashes(self) -> np.ndarray: + r"""Return a view of the hashes array (shares memory with the internal array).""" + return self._hashes.view() + + @property + def abundances(self) -> np.ndarray: + r"""Return a view of the abundances array (shares memory with the internal array).""" + return self._abundances.view() + + @property + def md5sum(self) -> str: + r"""Return the MD5 checksum of the signature.""" + return self._md5sum + + @property + def ksize(self) -> int: + r"""Return the k-mer size.""" + return self._ksize + + @property + def scale(self) -> int: + r"""Return the scale value.""" + return self._scale + + @property + def name(self) -> str: + r"""Return the name of the signature.""" + return self._name + + @property + def filename(self) -> str: + r"""Return the filename of the signature.""" + return self._filename + + @property + def sigtype(self) -> SigType: + r"""Return the type of the signature.""" + return self._type + + @property + def track_abundance(self) -> bool: + r"""Return whether the signature tracks abundance.""" + return 
self._track_abundance + + # Basic class methods + def get_name(self) -> str: + r"""Get the name of the signature.""" + return self._name + + # setter sigtype + @sigtype.setter + def sigtype(self, sigtype: SigType): + r""" + Set the type of the signature. + """ + self._type = sigtype + + def get_info(self) -> dict: + r""" + Get information about the signature. + + Returns: + dict: A dictionary containing signature information. + """ + info = { + "name": self._name, + "filename": self._filename, + "md5sum": self._md5sum, + "ksize": self._ksize, + "scale": self._scale, + "track_abundance": self._track_abundance, + "sigtype": self._type, + "num_hashes": len(self._hashes) + } + return info + + def __len__(self) -> int: + r"""Return the number of hashes in the signature.""" + return len(self._hashes) + + def __iter__(self) -> Iterator[tuple]: + r""" + Iterate over the hashes and their abundances. + + Yields: + tuple: A tuple containing (hash, abundance). + """ + for h, a in zip(self._hashes, self._abundances): + yield (h, a) + + def __contains__(self, hash_value: int) -> bool: + r""" + Check if a hash is present in the signature. + + Parameters: + hash_value (int): The hash value to check. + + Returns: + bool: True if the hash is present, False otherwise. + """ + # Utilize binary search since hashes are sorted + index = np.searchsorted(self._hashes, hash_value) + if index < len(self._hashes) and self._hashes[index] == hash_value: + return True + return False + + def __repr__(self) -> str: + return (f"SnipeSig(name={self._name}, ksize={self._ksize}, scale={self._scale}, " + f"type={self._type}, num_hashes={len(self._hashes)})") + + def __str__(self) -> str: + return self.__repr__() + + def __verify_snipe_signature(self, other: 'SnipeSig'): + r""" + Verify that the other object is a SnipeSig instance. + + Parameters: + other (SnipeSig): The other signature to verify. + + Raises: + ValueError: If the other object is not a SnipeSig instance. 
+ """ + if not isinstance(other, SnipeSig): + msg = f"Provided sig ({type(other).__name__}) is not a SnipeSig instance." + self.logger.error(msg) + raise ValueError(msg) + + def __verify_matching_ksize_scale(self, other: 'SnipeSig'): + r""" + Verify that the ksize and scale match between two signatures. + + Parameters: + other (SnipeSig): The other signature to compare. + + Raises: + ValueError: If ksize or scale do not match. + """ + if self._ksize != other.ksize: + _e_msg = f"K-mer size does not match between the two signatures: {self._ksize} vs {other.ksize}." + self.logger.error(_e_msg) + raise ValueError(_e_msg) + if self._scale != other.scale: + _e_msg = f"Scale value does not match between the two signatures: {self._scale} vs {other.scale}." + self.logger.error(_e_msg) + raise ValueError(_e_msg) + + def _validate_abundance_operation(self, value: Union[int, None], operation: str): + r""" + Validate that the signature tracks abundance and that the provided value is a non-negative integer. + + Parameters: + value (int or None): The abundance value to validate. Can be None for operations that don't require a value. + operation (str): Description of the operation for logging purposes. + + Raises: + ValueError: If the signature does not track abundance or if the value is invalid. + """ + if not self._track_abundance and self.sigtype == SigType.SAMPLE: + self.logger.error("Cannot %s: signature does not track abundance.", operation) + raise ValueError("Signature does not track abundance.") + + if value is not None: + if not isinstance(value, int) or value < 0: + self.logger.error("%s requires a non-negative integer value.", operation.capitalize()) + raise ValueError(f"{operation.capitalize()} requires a non-negative integer value.") + + # Mask application method + def _apply_mask(self, mask: np.ndarray): + r""" + Apply a boolean mask to the hashes and abundances arrays. + Ensures that the sorted order is preserved. 
+ + Parameters: + mask (np.ndarray): Boolean array indicating which elements to keep. + """ + self._hashes = self._hashes[mask] + self._abundances = self._abundances[mask] + + # Verify that the hashes remain sorted + if self._hashes.size > 1: + if not np.all(self._hashes[:-1] <= self._hashes[1:]): + self.logger.error("Hashes are not sorted after applying mask.") + raise RuntimeError("Hashes are not sorted after applying mask.") + self.logger.debug("Applied mask. Hashes remain sorted.") + + # Set operation methods + def union_sigs(self, other: 'SnipeSig') -> 'SnipeSig': + r""" + Combine this signature with another by summing abundances where hashes overlap. + + Given two signatures \( A \) and \( B \) with hash sets \( H_A \) and \( H_B \), + and their corresponding abundance functions \( a_A \) and \( a_B \), the union + signature \( C \) is defined as follows: + + - **Hash Set**: + + $$ + H_C = H_A \cup H_B + $$ + + - **Abundance Function**: + + $$ + a_C(h) = + \begin{cases} + a_A(h) + a_B(h), & \text{if } h \in H_A \cap H_B \\ + a_A(h), & \text{if } h \in H_A \setminus H_B \\ + a_B(h), & \text{if } h \in H_B \setminus H_A + \end{cases} + $$ + """ + self.__verify_snipe_signature(other) + self.__verify_matching_ksize_scale(other) + + self.logger.debug("Unioning signatures (including all unique hashes).") + + # Access internal arrays directly + self_hashes = self._hashes + self_abundances = self._abundances + other_hashes = other._hashes + other_abundances = other._abundances + + # Handle the case where 'other' does not track abundance + if not other.track_abundance: + self.logger.debug("Other signature does not track abundance. 
Setting abundances to 1.") + other_abundances = np.ones_like(other_abundances, dtype=np.uint32) + + # Combine hashes and abundances + combined_hashes = np.concatenate((self_hashes, other_hashes)) + combined_abundances = np.concatenate((self_abundances, other_abundances)) + + # Use numpy's unique function with return_inverse to sum abundances efficiently + unique_hashes, inverse_indices = np.unique(combined_hashes, return_inverse=True) + summed_abundances = np.zeros_like(unique_hashes, dtype=np.uint32) + + # Sum abundances for duplicate hashes + np.add.at(summed_abundances, inverse_indices, combined_abundances) + + # Handle potential overflow + summed_abundances = np.minimum(summed_abundances, np.iinfo(np.uint32).max) + + self.logger.debug("Union operation completed. Total hashes: %d", len(unique_hashes)) + + # Create a new SnipeSig instance + return self.create_from_hashes_abundances( + hashes=unique_hashes, + abundances=summed_abundances, + ksize=self._ksize, + scale=self._scale, + name=f"{self._name}_union_{other._name}", + filename=None, + enable_logging=self.logger.level <= logging.DEBUG + ) + + def _convert_to_sourmash_signature(self): + r""" + Convert the SnipeSig instance to a sourmash.signature.SourmashSignature object. + + Returns: + sourmash.signature.SourmashSignature: A new sourmash.signature.SourmashSignature instance. + """ + self.logger.debug("Converting SnipeSig to sourmash.signature.SourmashSignature.") + + mh = sourmash.minhash.MinHash(n=0, ksize=self._ksize, scaled=self._scale, track_abundance=self._track_abundance) + mh.set_abundances(dict(zip(self._hashes, self._abundances))) + self.sourmash_sig = sourmash.signature.SourmashSignature(mh, name=self._name, filename=self._filename) + self.logger.debug("Conversion to sourmash.signature.SourmashSignature completed.") + + def export(self, path) -> None: + r""" + Export the signature to a file. + + Parameters: + path (str): The path to save the signature to. 
+ """ + self._convert_to_sourmash_signature() + with open(str(path), "wb") as fp: + sourmash.signature.save_signatures_to_json([self.sourmash_sig], fp) + + def export_to_string(self): + r""" + Export the signature to a JSON string. + + Returns: + str: JSON string representation of the signature. + """ + self._convert_to_sourmash_signature() + return sourmash.signature.save_signatures_to_json([self.sourmash_sig]).decode('utf-8') + + def intersection_sigs(self, other: 'SnipeSig') -> 'SnipeSig': + r""" + Compute the intersection of the current signature with another signature. + + This method keeps only the hashes that are common to both signatures, and retains the abundances from self. + + **Mathematical Explanation**: + + Let \( A \) and \( B \) be two signatures with sets of hashes \( H_A \) and \( H_B \), + and abundance functions \( a_A(h) \) and \( a_B(h) \), the intersection signature \( C \) has: + + - Hash set: + $$ + H_C = H_A \cap H_B + $$ + + - Abundance function: + $$ + a_C(h) = a_A(h), \quad \text{for } h \in H_C + $$ + + **Parameters**: + - `other (SnipeSig)`: Another `SnipeSig` instance to intersect with. + + **Returns**: + - `SnipeSig`: A new `SnipeSig` instance representing the intersection of the two signatures. + + **Raises**: + - `ValueError`: If `ksize` or `scale` do not match between signatures. + """ + self.__verify_snipe_signature(other) + self.__verify_matching_ksize_scale(other) + + self.logger.debug("Intersecting signatures.") + + # Use numpy's intersect1d function + common_hashes, self_indices, _ = np.intersect1d( + self._hashes, other._hashes, assume_unique=True, return_indices=True + ) + + if common_hashes.size == 0: + self.logger.debug("No common hashes found. 
Returning an empty signature.") + return self.create_from_hashes_abundances( + hashes=np.array([], dtype=np.uint64), + abundances=np.array([], dtype=np.uint32), + ksize=self._ksize, + scale=self._scale, + name=f"{self._name}_intersection_{other._name}", + filename=None, + enable_logging=self.logger.level <= logging.DEBUG + ) + + # Get the abundances from self + common_abundances = self._abundances[self_indices] + + self.logger.debug("Intersection operation completed. Total common hashes: %d", len(common_hashes)) + + # Create a new SnipeSig instance + return self.create_from_hashes_abundances( + hashes=common_hashes, + abundances=common_abundances, + ksize=self._ksize, + scale=self._scale, + name=f"{self._name}_intersection_{other._name}", + filename=None, + enable_logging=self.logger.level <= logging.DEBUG + ) + + def difference_sigs(self, other: 'SnipeSig') -> 'SnipeSig': + r""" + Compute the difference of the current signature with another signature. + + This method removes hashes that are present in the other signature from self, + keeping the abundances from self. + + **Mathematical Explanation**: + + Let \( A \) and \( B \) be two signatures with sets of hashes \( H_A \) and \( H_B \), + and abundance function \( a_A(h) \), the difference signature \( C \) has: + + - Hash set: + $$ + H_C = H_A \setminus H_B + $$ + + - Abundance function: + $$ + a_C(h) = a_A(h), \quad \text{for } h \in H_C + $$ + + **Parameters**: + - `other (SnipeSig)`: Another `SnipeSig` instance to subtract from the current signature. + + **Returns**: + - `SnipeSig`: A new `SnipeSig` instance representing the difference of the two signatures. + + **Raises**: + - `ValueError`: If `ksize` or `scale` do not match between signatures. + - `RuntimeError`: If zero hashes remain after difference. 
+ """ + self.__verify_snipe_signature(other) + self.__verify_matching_ksize_scale(other) + + self.logger.debug("Differencing signatures.") + + # Use numpy's setdiff1d function + diff_hashes = np.setdiff1d(self._hashes, other._hashes, assume_unique=True) + + if diff_hashes.size == 0: + _e_msg = f"Difference operation resulted in zero hashes, which is not allowed for {self._name} and {other._name}." + self.logger.warning(_e_msg) + + # Get the indices of the hashes in self + mask = np.isin(self._hashes, diff_hashes, assume_unique=True) + diff_abundances = self._abundances[mask] + + self.logger.debug("Difference operation completed. Remaining hashes: %d", len(diff_hashes)) + + # Create a new SnipeSig instance + return self.create_from_hashes_abundances( + hashes=diff_hashes, + abundances=diff_abundances, + ksize=self._ksize, + scale=self._scale, + name=f"{self._name}_difference_{other._name}", + filename=None, + enable_logging=self.logger.level <= logging.DEBUG + ) + + def symmetric_difference_sigs(self, other: 'SnipeSig') -> 'SnipeSig': + r""" + Compute the symmetric difference of the current signature with another signature. + + This method retains hashes that are unique to each signature, with their respective abundances. + + **Mathematical Explanation**: + + Let \( A \) and \( B \) be two signatures with sets of hashes \( H_A \) and \( H_B \), + and abundance functions \( a_A(h) \) and \( a_B(h) \), the symmetric difference signature \( C \) has: + + - Hash set: + $$ + H_C = (H_A \setminus H_B) \cup (H_B \setminus H_A) + $$ + + - Abundance function: + $$ + a_C(h) = + \begin{cases} + a_A(h), & \text{for } h \in H_A \setminus H_B \\ + a_B(h), & \text{for } h \in H_B \setminus H_A \\ + \end{cases} + $$ + + **Parameters**: + - `other (SnipeSig)`: Another `SnipeSig` instance to compute the symmetric difference with. + + **Returns**: + - `SnipeSig`: A new `SnipeSig` instance representing the symmetric difference of the two signatures. 
+ + **Raises**: + - `ValueError`: If `ksize` or `scale` do not match between signatures. + - `RuntimeError`: If zero hashes remain after symmetric difference. + """ + self.__verify_snipe_signature(other) + self.__verify_matching_ksize_scale(other) + + self.logger.debug("Computing symmetric difference of signatures.") + + # Hashes unique to self and other + unique_self_hashes = np.setdiff1d(self._hashes, other._hashes, assume_unique=True) + unique_other_hashes = np.setdiff1d(other._hashes, self._hashes, assume_unique=True) + + # Abundances for unique hashes + mask_self = np.isin(self._hashes, unique_self_hashes, assume_unique=True) + unique_self_abundances = self._abundances[mask_self] + + mask_other = np.isin(other._hashes, unique_other_hashes, assume_unique=True) + unique_other_abundances = other._abundances[mask_other] + + # Handle the case where 'other' does not track abundance + if not other.track_abundance: + self.logger.debug("Other signature does not track abundance. Setting abundances to 1.") + unique_other_abundances = np.ones_like(unique_other_abundances, dtype=np.uint32) + + # Combine hashes and abundances + combined_hashes = np.concatenate((unique_self_hashes, unique_other_hashes)) + combined_abundances = np.concatenate((unique_self_abundances, unique_other_abundances)) + + if combined_hashes.size == 0: + _e_msg = "Symmetric difference operation resulted in zero hashes, which is not allowed." + self.logger.error(_e_msg) + raise RuntimeError(_e_msg) + + # Sort combined hashes and abundances + sorted_indices = np.argsort(combined_hashes) + combined_hashes = combined_hashes[sorted_indices] + combined_abundances = combined_abundances[sorted_indices] + + self.logger.debug("Symmetric difference operation completed. 
Total unique hashes: %d", len(combined_hashes)) + + # Create a new SnipeSig instance + return self.create_from_hashes_abundances( + hashes=combined_hashes, + abundances=combined_abundances, + ksize=self._ksize, + scale=self._scale, + name=f"{self._name}_symmetric_difference_{other._name}", + filename=None, + enable_logging=self.logger.level <= logging.DEBUG + ) + + # Magic methods for union operations + def __add__(self, other: 'SnipeSig') -> 'SnipeSig': + r""" + Implements the + operator. + Includes all unique hashes from both signatures and sums their abundances where hashes overlap, + returning a new signature. + + Returns: + SnipeSig: Union of self and other. + """ + return self.union_sigs(other) + + def __iadd__(self, other: 'SnipeSig') -> 'SnipeSig': + r""" + Implements the += operator. + Includes all unique hashes from both signatures and sums their abundances where hashes overlap, + modifying self in-place. + + Returns: + SnipeSig: Updated self after addition. + """ + union_sig = self.union_sigs(other) + self._update_from_union(union_sig) + return self + + def __or__(self, other: 'SnipeSig') -> 'SnipeSig': + r""" + Implements the | operator. + Includes all unique hashes from both signatures and sums their abundances where hashes overlap, + returning a new signature. + + Returns: + SnipeSig: Union of self and other. + """ + return self.union_sigs(other) + + def __ior__(self, other: 'SnipeSig') -> 'SnipeSig': + r""" + Implements the |= operator. + Includes all unique hashes from both signatures and sums their abundances where hashes overlap, + modifying self in-place. + + Returns: + SnipeSig: Updated self after union. + """ + union_sig = self.union_sigs(other) + self._update_from_union(union_sig) + return self + + def __sub__(self, other: 'SnipeSig') -> 'SnipeSig': + r""" + Implements the - operator. + Removes hashes present in other from self, keeping abundances from self, + returning a new signature. + + Returns: + SnipeSig: Difference of self and other. 
+ """ + return self.difference_sigs(other) + + def __isub__(self, other: 'SnipeSig') -> 'SnipeSig': + r""" + Implements the -= operator. + Removes hashes present in other from self, keeping abundances from self, + modifying self in-place. + + Returns: + SnipeSig: Updated self after difference. + + Raises: + RuntimeError: If zero hashes remain after difference. + """ + difference_sig = self.difference_sigs(other) + self._update_from_union(difference_sig) + return self + + def __xor__(self, other: 'SnipeSig') -> 'SnipeSig': + r""" + Implements the ^ operator. + Keeps unique hashes from each signature with their respective abundances, returning a new signature. + + Returns: + SnipeSig: Symmetric difference of self and other. + """ + return self.symmetric_difference_sigs(other) + + def __ixor__(self, other: 'SnipeSig') -> 'SnipeSig': + r""" + Implements the ^= operator. + Keeps unique hashes from each signature with their respective abundances, modifying self in-place. + + Returns: + SnipeSig: Updated self after symmetric difference. + + Raises: + RuntimeError: If zero hashes remain after symmetric difference. + """ + symmetric_diff_sig = self.symmetric_difference_sigs(other) + self._update_from_union(symmetric_diff_sig) + return self + + def __and__(self, other: 'SnipeSig') -> 'SnipeSig': + r""" + Implements the & operator. + Keeps common hashes and retains abundances from self only, returning a new signature. + + Returns: + SnipeSig: Intersection of self and other. + """ + return self.intersection_sigs(other) + + def _update_from_union(self, other: 'SnipeSig'): + r""" + Update self's hashes and abundances from another SnipeSig instance. + + Parameters: + other (SnipeSig): The other SnipeSig instance to update from. 
+ """ + self._hashes = other.hashes + self._abundances = other.abundances + self._name = other.name + self._filename = other.filename + self._md5sum = other.md5sum + self._track_abundance = other.track_abundance + # No need to update ksize and scale since they are verified to match + + @classmethod + def create_from_hashes_abundances(cls, hashes: np.ndarray, abundances: np.ndarray, + ksize: int, scale: int, name: str = None, + filename: str = None, enable_logging: bool = False, sig_type: SigType = SigType.SAMPLE) -> 'SnipeSig': + """ + Internal method to create a SnipeSig instance from hashes and abundances. + + Parameters: + hashes (np.ndarray): Array of hash values. + abundances (np.ndarray): Array of abundance values corresponding to the hashes. + ksize (int): K-mer size. + scale (int): Scale value. + name (str): Optional name for the signature. + filename (str): Optional filename for the signature. + sig_type (SigType): Type of the signature. + enable_logging (bool): Flag to enable logging. + + Returns: + SnipeSig: A new SnipeSig instance. + """ + # Create a mock sourmash signature object + mh = sourmash.minhash.MinHash(n=0, ksize=ksize, scaled=scale, track_abundance=True) + mh.set_abundances(dict(zip(hashes, abundances))) + sig = sourmash.signature.SourmashSignature(mh, name=name or "", filename=filename or "") + return cls(sourmash_sig=sig, sig_type=sig_type, enable_logging=enable_logging) + + # Aggregation Operations + @classmethod + def sum_signatures(cls, signatures: List['SnipeSig'], name: str = "summed_signature", + filename: str = None, enable_logging: bool = False) -> 'SnipeSig': + + r""" + Sum multiple SnipeSig instances by including all unique hashes and summing their abundances where hashes overlap. + This method utilizes a heap-based multi-way merge algorithm for enhanced efficiency when handling thousands of signatures. 
+ + $$ + \text{Sum}(A_1, A_2, \dots, A_n) = \bigcup_{i=1}^{n} A_i + $$ + + For each hash \( h \), its total abundance is: + $$ + \text{abundance}(h) = \sum_{i=1}^{n} \text{abundance}_i(h) + $$ + + **Mathematical Explanation**: + + - **Union of Signatures**: + The summation of signatures involves creating a union of all unique k-mers (hashes) present across the input signatures. + + - **Total Abundance Calculation**: + For each unique hash \( h \), the total abundance is the sum of its abundances across all signatures where it appears. + + - **Algorithm Efficiency**: + By using a min-heap to perform a multi-way merge of sorted hash arrays, the method ensures that each hash is processed in ascending order without the need to store all hashes in memory simultaneously. + + **Parameters**: + - `signatures (List[SnipeSig])`: List of `SnipeSig` instances to sum. + - `name (str)`: Optional name for the resulting signature. + - `filename (str)`: Optional filename for the resulting signature. + - `enable_logging (bool)`: Flag to enable detailed logging. + + **Returns**: + - `SnipeSig`: A new `SnipeSig` instance representing the sum of the signatures. + + **Raises**: + - `ValueError`: If the signatures list is empty or if `ksize`/`scale` do not match across signatures. + - `RuntimeError`: If an error occurs during the summation process. 
+ """ + if not signatures: + raise ValueError("No signatures provided for summation.") + + # Verify that all signatures have the same ksize, scale, and track_abundance + first_sig = signatures[0] + ksize = first_sig.ksize + scale = first_sig.scale + track_abundance = first_sig.track_abundance + + for sig in signatures[1:]: + if sig.ksize != ksize or sig.scale != scale: + raise ValueError("All signatures must have the same ksize and scale.") + if sig.track_abundance != track_abundance: + raise ValueError("All signatures must have the same track_abundance setting.") + + # Initialize iterators for each signature's hashes and abundances + iterators = [] + for sig in signatures: + it = iter(zip(sig.hashes, sig.abundances)) + try: + first_hash, first_abundance = next(it) + iterators.append((first_hash, first_abundance, it)) + except StopIteration: + continue # Skip empty signatures + + if not iterators: + raise ValueError("All provided signatures are empty.") + + # Initialize the heap with the first element from each iterator + heap = [] + for idx, (hash_val, abundance, it) in enumerate(iterators): + heap.append((hash_val, abundance, idx)) + heapq.heapify(heap) + + # Prepare lists to collect the summed hashes and abundances + summed_hashes = [] + summed_abundances = [] + + while heap: + current_hash, current_abundance, idx = heapq.heappop(heap) + # Initialize total abundance for the current_hash + total_abundance = current_abundance + + # Check if the next element in the heap has the same hash + while heap and heap[0][0] == current_hash: + _, abundance, same_idx = heapq.heappop(heap) + total_abundance += abundance + # Push the next element from the same iterator + try: + next_hash, next_abundance = next(iterators[same_idx][2]) + heapq.heappush(heap, (next_hash, next_abundance, same_idx)) + except StopIteration: + pass # No more elements in this iterator + + # Append the summed hash and abundance + summed_hashes.append(current_hash) + 
summed_abundances.append(total_abundance) + + # Push the next element from the current iterator + try: + next_hash, next_abundance = next(iterators[idx][2]) + heapq.heappush(heap, (next_hash, next_abundance, idx)) + except StopIteration: + pass # No more elements in this iterator + + # Convert the results to NumPy arrays for efficient storage and processing + summed_hashes = np.array(summed_hashes, dtype=np.uint64) + summed_abundances = np.array(summed_abundances, dtype=np.uint32) + + # Handle potential overflow by capping at the maximum value of uint32 + summed_abundances = np.minimum(summed_abundances, np.iinfo(np.uint32).max) + + # Create a new SnipeSig instance from the summed hashes and abundances + summed_signature = cls.create_from_hashes_abundances( + hashes=summed_hashes, + abundances=summed_abundances, + ksize=ksize, + scale=scale, + name=name, + filename=filename, + enable_logging=enable_logging + ) + + return summed_signature + + @staticmethod + def get_unique_signatures(signatures: Dict[str, 'SnipeSig']) -> Dict[str, 'SnipeSig']: + """ + Extract unique signatures from a dictionary of SnipeSig instances. + + For each signature, the unique_sig contains only the hashes that do not overlap with any other signature. + + Parameters: + signatures (Dict[str, SnipeSig]): A dictionary mapping signature names to SnipeSig instances. + + Returns: + Dict[str, SnipeSig]: A dictionary mapping signature names to their unique SnipeSig instances. + + Raises: + ValueError: If the input dictionary is empty or if signatures have mismatched ksize/scale. 
+ """ + if not signatures: + raise ValueError("The input signatures dictionary is empty.") + + # Extract ksize and scale from the first signature + first_name, first_sig = next(iter(signatures.items())) + ksize = first_sig.ksize + scale = first_sig.scale + + # Verify that all signatures have the same ksize and scale + for name, sig in signatures.items(): + if sig.ksize != ksize or sig.scale != scale: + raise ValueError(f"Signature '{name}' has mismatched ksize or scale.") + + # Aggregate all hashes from all signatures + all_hashes = np.concatenate([sig.hashes for sig in signatures.values()]) + + # Count the occurrences of each hash + unique_hashes, counts = np.unique(all_hashes, return_counts=True) + + # Identify hashes that are unique across all signatures (count == 1) + unique_across_all = unique_hashes[counts == 1] + + # Convert to a set for faster membership testing + unique_set = set(unique_across_all) + + unique_signatures = {} + + for name, sig in signatures.items(): + # Find hashes in the current signature that are unique across all signatures + mask_unique = np.isin(sig.hashes, list(unique_set)) + + # Extract unique hashes and their abundances + unique_hashes_sig = sig.hashes[mask_unique] + unique_abundances_sig = sig.abundances[mask_unique] + + # Create a new SnipeSig instance with the unique hashes and abundances + unique_sig = SnipeSig.create_from_hashes_abundances( + hashes=unique_hashes_sig, + abundances=unique_abundances_sig, + ksize=ksize, + scale=scale, + name=f"{name}_unique", + filename=None, + enable_logging=False, # Set to True if you want logging for the new signatures + sig_type=SigType.SAMPLE # Adjust sig_type as needed + ) + + unique_signatures[name] = unique_sig + + return unique_signatures + + + @classmethod + def common_hashes(cls, signatures: List['SnipeSig'], name: str = "common_hashes_signature", + filename: str = None, enable_logging: bool = False) -> 'SnipeSig': + r""" + Compute the intersection of multiple SnipeSig instances, 
returning a new SnipeSig containing + only the hashes present in all signatures, with abundances set to the minimum abundance across signatures. + + This method uses a heap-based multi-way merge algorithm for efficient computation, + especially when handling a large number of signatures with sorted hashes. + + **Mathematical Explanation**: + + Given signatures \( A_1, A_2, \dots, A_n \) with hash sets \( H_1, H_2, \dots, H_n \), + the intersection signature \( C \) has: + + - Hash set: + $$ + H_C = \bigcap_{i=1}^{n} H_i + $$ + + - Abundance function: + $$ + a_C(h) = \min_{i=1}^{n} a_i(h), \quad \text{for } h \in H_C + $$ + + **Parameters**: + - `signatures (List[SnipeSig])`: List of `SnipeSig` instances to compute the intersection. + - `name (str)`: Optional name for the resulting signature. + - `filename (str)`: Optional filename for the resulting signature. + - `enable_logging (bool)`: Flag to enable detailed logging. + + **Returns**: + - `SnipeSig`: A new `SnipeSig` instance representing the intersection of the signatures. + + **Raises**: + - `ValueError`: If the signatures list is empty or if `ksize`/`scale` do not match across signatures. 
+ """ + if not signatures: + raise ValueError("No signatures provided for intersection.") + + # Verify that all signatures have the same ksize and scale + first_sig = signatures[0] + ksize = first_sig.ksize + scale = first_sig.scale + for sig in signatures[1:]: + if sig.ksize != ksize or sig.scale != scale: + raise ValueError("All signatures must have the same ksize and scale.") + + num_signatures = len(signatures) + iterators = [] + for sig in signatures: + it = iter(zip(sig.hashes, sig.abundances)) + try: + first_hash, first_abundance = next(it) + iterators.append((first_hash, first_abundance, it)) + except StopIteration: + # One of the signatures is empty; intersection is empty + return cls.create_from_hashes_abundances( + hashes=np.array([], dtype=np.uint64), + abundances=np.array([], dtype=np.uint32), + ksize=ksize, + scale=scale, + name=name, + filename=filename, + enable_logging=enable_logging + ) + + # Initialize the heap with the first element from each iterator + heap = [] + for idx, (hash_val, abundance, it) in enumerate(iterators): + heap.append((hash_val, abundance, idx)) + heapq.heapify(heap) + + common_hashes = [] + common_abundances = [] + + while heap: + # Pop all entries with the smallest hash + current_hash, current_abundance, idx = heapq.heappop(heap) + same_hash_entries = [(current_hash, current_abundance, idx)] + + # Collect all entries in the heap that have the same current_hash + while heap and heap[0][0] == current_hash: + h, a, i = heapq.heappop(heap) + same_hash_entries.append((h, a, i)) + + if len(same_hash_entries) == num_signatures: + # The current_hash is present in all signatures + # Take the minimum abundance across signatures + min_abundance = min(entry[1] for entry in same_hash_entries) + common_hashes.append(current_hash) + common_abundances.append(min_abundance) + + # Push the next element from each iterator that had the current_hash + for entry in same_hash_entries: + h, a, i = entry + try: + next_hash, next_abundance = 
next(iterators[i][2]) + heapq.heappush(heap, (next_hash, next_abundance, i)) + except StopIteration: + pass # Iterator exhausted + + # Convert the results to NumPy arrays + if not common_hashes: + # No common hashes found + unique_hashes = np.array([], dtype=np.uint64) + unique_abundances = np.array([], dtype=np.uint32) + else: + unique_hashes = np.array(common_hashes, dtype=np.uint64) + unique_abundances = np.array(common_abundances, dtype=np.uint32) + + # Create a new SnipeSig instance from the common hashes and abundances + common_signature = cls.create_from_hashes_abundances( + hashes=unique_hashes, + abundances=unique_abundances, + ksize=ksize, + scale=scale, + name=name, + filename=filename, + enable_logging=enable_logging + ) + + return common_signature + + def copy(self) -> 'SnipeSig': + r""" + Create a copy of the current SnipeSig instance. + + Returns: + SnipeSig: A new instance that is a copy of self. + """ + return SnipeSig(sourmash_sig=self.export_to_string(), sig_type=self.sigtype, enable_logging=self.logger.level <= logging.DEBUG) + + # Implement the __radd__ method to support sum() + def __radd__(self, other: Union[int, 'SnipeSig']) -> 'SnipeSig': + r""" + Implements the right-hand + operator to support sum(). + + Returns: + SnipeSig: Union of self and other. + """ + return self.__radd_sum__(other) + + # Override the __sum__ method + def __radd_sum__(self, other: Union[int, 'SnipeSig']) -> 'SnipeSig': + r""" + Internal helper method to support the sum() function. + + Parameters: + other (int or SnipeSig): The other object to add. If other is 0, return self. + + Returns: + SnipeSig: The result of the addition. + """ + if other == 0: + return self + if not isinstance(other, SnipeSig): + raise TypeError(f"Unsupported operand type(s) for +: 'SnipeSig' and '{type(other).__name__}'") + return self.union_sigs(other) + + def reset_abundance(self, new_abundance: int = 1): + r""" + Reset all abundances to a specified value. 
+ + This method sets the abundance of every hash in the signature to the specified `new_abundance` value. + + **Mathematical Explanation**: + + For each hash \( h \) in the signature, the abundance function is updated to: + $$ + a(h) = \text{new\_abundance} + $$ + + **Parameters**: + - `new_abundance (int)`: The new abundance value to set for all hashes. Default is 1. + + **Raises**: + - `ValueError`: If the signature does not track abundance or if `new_abundance` is invalid. + """ + + self._validate_abundance_operation(new_abundance, "reset abundance") + + self._abundances[:] = new_abundance + self.logger.debug("Reset all abundances to %d.", new_abundance) + + def keep_min_abundance(self, min_abundance: int): + r""" + Keep only hashes with abundances greater than or equal to a minimum threshold. + + This method removes hashes whose abundances are less than the specified `min_abundance`. + + **Mathematical Explanation**: + + The updated hash set \( H' \) is: + $$ + H' = \{ h \in H \mid a(h) \geq \text{min\_abundance} \} + $$ + + **Parameters**: + - `min_abundance (int)`: The minimum abundance threshold. + + **Raises**: + - `ValueError`: If the signature does not track abundance or if `min_abundance` is invalid. + """ + self._validate_abundance_operation(min_abundance, "keep minimum abundance") + + mask = self._abundances >= min_abundance + self._apply_mask(mask) + self.logger.debug("Kept hashes with abundance >= %d.", min_abundance) + + def keep_max_abundance(self, max_abundance: int): + r""" + Keep only hashes with abundances less than or equal to a maximum threshold. + + This method removes hashes whose abundances are greater than the specified `max_abundance`. + + **Mathematical Explanation**: + + The updated hash set \( H' \) is: + $$ + H' = \{ h \in H \mid a(h) \leq \text{max\_abundance} \} + $$ + + **Parameters**: + - `max_abundance (int)`: The maximum abundance threshold. 
+ + **Raises**: + - `ValueError`: If the signature does not track abundance or if `max_abundance` is invalid. + """ + self._validate_abundance_operation(max_abundance, "keep maximum abundance") + + mask = self._abundances <= max_abundance + self._apply_mask(mask) + self.logger.debug("Kept hashes with abundance <= %d.", max_abundance) + + def trim_below_median(self): + r""" + Trim hashes with abundances below the median abundance. + + This method removes all hashes whose abundances are less than the median abundance of the signature. + + **Mathematical Explanation**: + + Let \\( m \\) be the median of \\( \\{ a(h) \mid h \in H \\} \\). + The updated hash set \\( H' \\) is: + + $$ + H' = \\{ h \in H \mid a(h) \geq m \\} + $$ + + **Raises**: + - `ValueError`: If the signature does not track abundance. + """ + + self._validate_abundance_operation(None, "trim below median") + + if len(self._abundances) == 0: + self.logger.debug("No hashes to trim based on median abundance.") + return + + median = np.median(self._abundances) + mask = self._abundances >= median + self._apply_mask(mask) + self.logger.debug("Trimmed hashes with abundance below median (%f).", median) + + def count_singletons(self) -> int: + r""" + Return the number of hashes with abundance equal to 1. + + Returns: + int: Number of singletons. + + Raises: + ValueError: If the signature does not track abundance. + """ + self._validate_abundance_operation(None, "count singletons") + + count = np.sum(self._abundances == 1) + self.logger.debug("Number of singletons (abundance == 1): %d", count) + return int(count) + + def trim_singletons(self): + r""" + Remove hashes with abundance equal to 1. + + This method removes all hashes that are singletons (abundance equals 1). + + **Mathematical Explanation**: + + The updated hash set \( H' \) is: + $$ + H' = \{ h \in H \mid a(h) \neq 1 \} + $$ + + **Raises**: + - `ValueError`: If the signature does not track abundance. 
+ """ + self._validate_abundance_operation(None, "trim singletons") + + mask = self._abundances != 1 + self.logger.debug("Trimming %d hashes with abundance equal to 1.", np.sum(~mask)) + self._apply_mask(mask) + self.logger.debug("Size after trimming singletons: %d", len(self._hashes)) + + # New Properties Implemented as per Request + + @property + def total_abundance(self) -> int: + r""" + Return the total abundance (sum of all abundances). + + Returns: + int: Total abundance. + """ + self._validate_abundance_operation(None, "calculate total abundance") + + total = int(np.sum(self._abundances)) + self.logger.debug("Total abundance: %d", total) + return total + + @property + def mean_abundance(self) -> float: + r""" + Return the mean (average) abundance. + + Returns: + float: Mean abundance. + """ + self._validate_abundance_operation(None, "calculate mean abundance") + + if len(self._abundances) == 0: + self.logger.debug("No abundances to calculate mean.") + return 0.0 + + mean = float(np.mean(self._abundances)) # Changed to float + self.logger.debug("Mean abundance: %f", mean) + return mean + + @property + def get_sample_stats(self) -> dict: + r""" + Retrieve statistical information about the signature. + + This property computes and returns a dictionary containing various statistics of the signature, such as total abundance, mean and median abundances, number of singletons, and total number of hashes. + + **Returns**: + - `dict`: A dictionary containing sample statistics: + - `total_abundance`: Sum of abundances. + - `mean_abundance`: Mean abundance. + - `median_abundance`: Median abundance. + - `num_singletons`: Number of hashes with abundance equal to 1. + - `num_hashes`: Total number of hashes. + - `ksize`: K-mer size. + - `scale`: Scale value. + - `name`: Name of the signature. + - `filename`: Filename of the signature. 
+ """ + + # if self.sigtype != SigType.SAMPLE then don't return abundance stats + + stats = { + "num_hashes": len(self._hashes), + "ksize": self._ksize, + "scale": self._scale, + "name": self._name, + "filename": self._filename + } + + if self.sigtype != SigType.SAMPLE: + stats["total_abundance"] = None + stats["mean_abundance"] = None + stats["median_abundance"] = None + stats["num_singletons"] = None + else: + stats["total_abundance"] = self.total_abundance + stats["mean_abundance"] = self.mean_abundance + stats["median_abundance"] = self.median_abundance + stats["num_singletons"] = self.count_singletons() + + return stats + + @property + def median_abundance(self) -> float: + r""" + Return the median abundance. + + Returns: + float: Median abundance. + + Raises: + ValueError: If the signature does not track abundance. + """ + self._validate_abundance_operation(None, "calculate median abundance") + + if len(self._abundances) == 0: + self.logger.debug("No abundances to calculate median.") + return 0.0 + + median = float(np.median(self._abundances)) # Changed to float + self.logger.debug("Median abundance: %f", median) + return median From 4ac09c363abbb2c587d84c3204688a2eeee79754 Mon Sep 17 00:00:00 2001 From: Mohamed Abuelanin Date: Mon, 14 Oct 2024 11:17:50 -0700 Subject: [PATCH 5/7] more testing --- src/snipe/api/__init__.py | 1459 -------------------------------- tests/api/test_api.py | 2 +- tests/api/test_reference_qc.py | 299 ++++++- 3 files changed, 299 insertions(+), 1461 deletions(-) diff --git a/src/snipe/api/__init__.py b/src/snipe/api/__init__.py index 64f915e..e69de29 100644 --- a/src/snipe/api/__init__.py +++ b/src/snipe/api/__init__.py @@ -1,1459 +0,0 @@ -import heapq -import logging -import warnings -from snipe.api.enums import SigType -from typing import Any, Dict, Iterator, List, Optional, Union -import numpy as np -import sourmash - - -# Configure the root logger to CRITICAL to suppress unwanted logs by default 
-logging.basicConfig(level=logging.CRITICAL) - - -class SnipeSig: - """ - A class to handle Sourmash signatures with additional functionalities - such as customized set operations and abundance management. - """ - - def _try_load_from_json(self, sourmash_sig: str) -> Union[List[sourmash.signature.SourmashSignature], None]: - r""" - Attempt to load sourmash signature from JSON string. - - Parameters: - sourmash_sig (str): JSON string representing a sourmash signature. - - Returns: - sourmash.signature.SourmashSignature or None if loading fails. - """ - try: - self.logger.debug("Trying to load sourmash signature from JSON.") - list_of_sigs = list(sourmash.load_signatures_from_json(sourmash_sig)) - return {sig.name: sig for sig in list_of_sigs} - except Exception as e: - self.logger.debug("Loading from JSON failed. Proceeding to file loading.", exc_info=e) - return None # Return None to indicate failure - - def _try_load_from_file(self, sourmash_sig_path: str) -> Union[List[sourmash.signature.SourmashSignature], None]: - r""" - Attempt to load sourmash signature(s) from a file. - - Parameters: - sourmash_sig_path (str): File path to a sourmash signature. - - Returns: - sourmash.signature.SourmashßSignature, list of sourmash.signature.SourmashSignature, or None if loading fails. 
- """ - self.logger.debug("Trying to load sourmash signature from file.") - try: - signatures = list(sourmash.load_file_as_signatures(sourmash_sig_path)) - self.logger.debug("Loaded %d sourmash signature(s) from file.", len(signatures)) - sigs_dict = {_sig.name: _sig for _sig in signatures} - self.logger.debug("Loaded sourmash signatures into sigs_dict: %s", sigs_dict) - return sigs_dict - except Exception as e: - self.logger.exception("Failed to load the sourmash signature from the file.", exc_info=e) - raise ValueError("An unexpected error occurred while loading the sourmash signature.") from e - - - def __init__(self, *, sourmash_sig: Union[str, sourmash.signature.SourmashSignature, sourmash.signature.FrozenSourmashSignature], - ksize: int = 51, scale: int = 10000, sig_type=SigType.SAMPLE, enable_logging: bool = False, **kwargs): - r""" - Initialize the SnipeSig with a sourmash signature object or a path to a signature. - - Parameters: - sourmash_sig (str or sourmash.signature.SourmashSignature): A path to a signature file or a signature object. - ksize (int): K-mer size. - scale (int): Scale value. - sig_type (SigType): Type of the signature. - enable_logging (bool): Flag to enable detailed logging. - **kwargs: Additional keyword arguments. 
- """ - # Initialize logging based on the flag - self.logger = logging.getLogger(self.__class__.__name__) - - # Configure the logger - if enable_logging: - self.logger.setLevel(logging.DEBUG) - if not self.logger.hasHandlers(): - # Create console handler - ch = logging.StreamHandler() - ch.setLevel(logging.DEBUG) - # Create formatter - formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') - # Add formatter to handler - ch.setFormatter(formatter) - # Add handler to logger - self.logger.addHandler(ch) - self.logger.debug("Logging is enabled for SnipeSig.") - else: - self.logger.setLevel(logging.CRITICAL) - - # Initialize internal variables - self.logger.debug("Initializing SnipeSig with sourmash_sig: %s", sourmash_sig) - - self._scale = scale - self._ksize = ksize - self._md5sum = None - self._hashes = np.array([], dtype=np.uint64) - self._abundances = np.array([], dtype=np.uint32) - self._type = sig_type - self._name = None - self._filename = None - self._track_abundance = False - - sourmash_sigs: Dict[str, sourmash.signature.SourmashSignature] = {} - _sourmash_sig: Union[sourmash.signature.SourmashSignature, sourmash.signature.FrozenSourmashSignature] = None - - - self.logger.debug("Proceeding with a sigtype of %s", sig_type) - - - - - - - if not isinstance(sourmash_sig, (str, sourmash.signature.SourmashSignature, sourmash.signature.FrozenSourmashSignature)): - # if the str is not a file path - self.logger.error("Invalid type for sourmash_sig: %s", type(sourmash_sig).__name__) - raise TypeError(f"sourmash_sig must be a file path, sourmash.signature.SourmashSignature, or Frozensourmash_signature, got {type(sourmash_sig).__name__}") - - # Case 1: If sourmash_sig is already a valid sourmash signature object - if isinstance(sourmash_sig, (sourmash.signature.FrozenSourmashSignature, sourmash.signature.SourmashSignature)): - self.logger.debug("Loaded sourmash signature directly from object.") - sourmash_sigs = {sourmash_sig.name: 
sourmash_sig} - - # Case 2: If sourmash_sig is a string, try to load as JSON or a file - elif isinstance(sourmash_sig, str): - self.logger.debug("Attempting to load sourmash signature from string input.") - - # First, try loading from JSON - sourmash_sigs = self._try_load_from_json(sourmash_sig) - self.logger.debug("Loaded sourmash signature from JSON: %s", sourmash_sigs) - - # If JSON loading fails, try loading from file - if not sourmash_sigs: - sourmash_sigs = self._try_load_from_file(sourmash_sig) - - # If both attempts fail, raise an error - if not sourmash_sigs: - self.logger.error("Failed to load sourmash signature from the provided string.") - raise ValueError("An unexpected error occurred while loading the sourmash signature.") - - if sig_type == SigType.SAMPLE or sig_type == SigType.AMPLICON: - if len(sourmash_sigs) > 1: - self.logger.debug("Multiple signatures found in the input. Expected a single sample signature.") - # not supported at this time - raise ValueError("Loading multiple sample signatures is not supported at this time.") - elif len(sourmash_sigs) == 1: - self.logger.debug("Found a single signature in the sample sig input; Will use this signature.") - _sourmash_sig = list(sourmash_sigs.values())[0] - else: - self.logger.debug("No signature found in the input. Expected a single sample signature.") - raise ValueError("No signature found in the input. Expected a single sample signature.") - - elif sig_type == SigType.GENOME: - if len(sourmash_sigs) > 1: - for signame, sig in sourmash_sigs.items(): - if signame.endswith("-snipegenome"): - sig = sig.to_mutable() - sig.name = sig.name.replace("-snipegenome", "") - self.logger.debug("Found a genome signature with a snipe modified name. 
Restoring original name `%s`.", sig.name) - _sourmash_sig = sig - break - else: - self.logger.debug("Found multiple signature per the genome file, but none with a snipe modified name.") - raise ValueError("Found multiple signature per the genome file, but none with a snipe modified name.") - elif len(sourmash_sigs) == 1: - self.logger.debug("Found a single signature in the genome sig input; Will use this signature.") - _sourmash_sig = list(sourmash_sigs.values())[0] - else: - self.logger.debug("Unknown sigtype: %s", sig_type) - raise ValueError(f"Unknown sigtype: {sig_type}") - - self.logger.debug("Length of currently loaded signature: %d, with name: %s", len(_sourmash_sig), _sourmash_sig.name) - - # Extract properties from the loaded signature - self._ksize = _sourmash_sig.minhash.ksize - self._scale = _sourmash_sig.minhash.scaled - self._md5sum = _sourmash_sig.md5sum() - self._name = _sourmash_sig.name - self._filename = _sourmash_sig.filename - self._track_abundance = _sourmash_sig.minhash.track_abundance - - # If the signature does not track abundance, assume abundance of 1 for all hashes - if not self._track_abundance: - self.logger.debug("Signature does not track abundance. 
Setting all abundances to 1.") - self._abundances = np.ones(len(_sourmash_sig.minhash.hashes), dtype=np.uint32) - # self._track_abundance = True - else: - self._abundances = np.array(list(_sourmash_sig.minhash.hashes.values()), dtype=np.uint32) - - self._hashes = np.array(list(_sourmash_sig.minhash.hashes.keys()), dtype=np.uint64) - - # Sort the hashes and rearrange abundances accordingly - sorted_indices = np.argsort(self._hashes) - self._hashes = self._hashes[sorted_indices] - self._abundances = self._abundances[sorted_indices] - - self.logger.debug( - "Loaded sourmash signature from file: %s, name: %s, md5sum: %s, ksize: %d, scale: %d, " - "track_abundance: %s, type: %s, length: %d", - self._filename, self._name, self._md5sum, self._ksize, self._scale, - self._track_abundance, self._type, len(self._hashes) - ) - self.logger.debug("Hashes sorted during initialization.") - self.logger.debug("Sourmash signature loading completed successfully.") - - # Setters and getters - @property - def hashes(self) -> np.ndarray: - r"""Return a copy of the hashes array.""" - return self._hashes.view() - - @property - def abundances(self) -> np.ndarray: - r"""Return a copy of the abundances array.""" - return self._abundances.view() - - @property - def md5sum(self) -> str: - r"""Return the MD5 checksum of the signature.""" - return self._md5sum - - @property - def ksize(self) -> int: - r"""Return the k-mer size.""" - return self._ksize - - @property - def scale(self) -> int: - r"""Return the scale value.""" - return self._scale - - @property - def name(self) -> str: - r"""Return the name of the signature.""" - return self._name - - @property - def filename(self) -> str: - r"""Return the filename of the signature.""" - return self._filename - - @property - def sigtype(self) -> SigType: - r"""Return the type of the signature.""" - return self._type - - @property - def track_abundance(self) -> bool: - r"""Return whether the signature tracks abundance.""" - return 
self._track_abundance - - # Basic class methods - def get_name(self) -> str: - r"""Get the name of the signature.""" - return self._name - - # setter sigtype - @sigtype.setter - def sigtype(self, sigtype: SigType): - r""" - Set the type of the signature. - """ - self._type = sigtype - - def get_info(self) -> dict: - r""" - Get information about the signature. - - Returns: - dict: A dictionary containing signature information. - """ - info = { - "name": self._name, - "filename": self._filename, - "md5sum": self._md5sum, - "ksize": self._ksize, - "scale": self._scale, - "track_abundance": self._track_abundance, - "sigtype": self._type, - "num_hashes": len(self._hashes) - } - return info - - def __len__(self) -> int: - r"""Return the number of hashes in the signature.""" - return len(self._hashes) - - def __iter__(self) -> Iterator[tuple]: - r""" - Iterate over the hashes and their abundances. - - Yields: - tuple: A tuple containing (hash, abundance). - """ - for h, a in zip(self._hashes, self._abundances): - yield (h, a) - - def __contains__(self, hash_value: int) -> bool: - r""" - Check if a hash is present in the signature. - - Parameters: - hash_value (int): The hash value to check. - - Returns: - bool: True if the hash is present, False otherwise. - """ - # Utilize binary search since hashes are sorted - index = np.searchsorted(self._hashes, hash_value) - if index < len(self._hashes) and self._hashes[index] == hash_value: - return True - return False - - def __repr__(self) -> str: - return (f"SnipeSig(name={self._name}, ksize={self._ksize}, scale={self._scale}, " - f"type={self._type}, num_hashes={len(self._hashes)})") - - def __str__(self) -> str: - return self.__repr__() - - def __verify_snipe_signature(self, other: 'SnipeSig'): - r""" - Verify that the other object is a SnipeSig instance. - - Parameters: - other (SnipeSig): The other signature to verify. - - Raises: - ValueError: If the other object is not a SnipeSig instance. 
- """ - if not isinstance(other, SnipeSig): - msg = f"Provided sig ({type(other).__name__}) is not a SnipeSig instance." - self.logger.error(msg) - raise ValueError(msg) - - def __verify_matching_ksize_scale(self, other: 'SnipeSig'): - r""" - Verify that the ksize and scale match between two signatures. - - Parameters: - other (SnipeSig): The other signature to compare. - - Raises: - ValueError: If ksize or scale do not match. - """ - if self._ksize != other.ksize: - _e_msg = f"K-mer size does not match between the two signatures: {self._ksize} vs {other.ksize}." - self.logger.error(_e_msg) - raise ValueError(_e_msg) - if self._scale != other.scale: - _e_msg = f"Scale value does not match between the two signatures: {self._scale} vs {other.scale}." - self.logger.error(_e_msg) - raise ValueError(_e_msg) - - def _validate_abundance_operation(self, value: Union[int, None], operation: str): - r""" - Validate that the signature tracks abundance and that the provided value is a non-negative integer. - - Parameters: - value (int or None): The abundance value to validate. Can be None for operations that don't require a value. - operation (str): Description of the operation for logging purposes. - - Raises: - ValueError: If the signature does not track abundance or if the value is invalid. - """ - if not self._track_abundance and self.sigtype == SigType.SAMPLE: - self.logger.error("Cannot %s: signature does not track abundance.", operation) - raise ValueError("Signature does not track abundance.") - - if value is not None: - if not isinstance(value, int) or value < 0: - self.logger.error("%s requires a non-negative integer value.", operation.capitalize()) - raise ValueError(f"{operation.capitalize()} requires a non-negative integer value.") - - # Mask application method - def _apply_mask(self, mask: np.ndarray): - r""" - Apply a boolean mask to the hashes and abundances arrays. - Ensures that the sorted order is preserved. 
- - Parameters: - mask (np.ndarray): Boolean array indicating which elements to keep. - """ - self._hashes = self._hashes[mask] - self._abundances = self._abundances[mask] - - # Verify that the hashes remain sorted - if self._hashes.size > 1: - if not np.all(self._hashes[:-1] <= self._hashes[1:]): - self.logger.error("Hashes are not sorted after applying mask.") - raise RuntimeError("Hashes are not sorted after applying mask.") - self.logger.debug("Applied mask. Hashes remain sorted.") - - # Set operation methods - def union_sigs(self, other: 'SnipeSig') -> 'SnipeSig': - r""" - Combine this signature with another by summing abundances where hashes overlap. - - Given two signatures \( A \) and \( B \) with hash sets \( H_A \) and \( H_B \), - and their corresponding abundance functions \( a_A \) and \( a_B \), the union - signature \( C \) is defined as follows: - - - **Hash Set**: - - $$ - H_C = H_A \cup H_B - $$ - - - **Abundance Function**: - - $$ - a_C(h) = - \begin{cases} - a_A(h) + a_B(h), & \text{if } h \in H_A \cap H_B \\ - a_A(h), & \text{if } h \in H_A \setminus H_B \\ - a_B(h), & \text{if } h \in H_B \setminus H_A - \end{cases} - $$ - """ - self.__verify_snipe_signature(other) - self.__verify_matching_ksize_scale(other) - - self.logger.debug("Unioning signatures (including all unique hashes).") - - # Access internal arrays directly - self_hashes = self._hashes - self_abundances = self._abundances - other_hashes = other._hashes - other_abundances = other._abundances - - # Handle the case where 'other' does not track abundance - if not other.track_abundance: - self.logger.debug("Other signature does not track abundance. 
Setting abundances to 1.") - other_abundances = np.ones_like(other_abundances, dtype=np.uint32) - - # Combine hashes and abundances - combined_hashes = np.concatenate((self_hashes, other_hashes)) - combined_abundances = np.concatenate((self_abundances, other_abundances)) - - # Use numpy's unique function with return_inverse to sum abundances efficiently - unique_hashes, inverse_indices = np.unique(combined_hashes, return_inverse=True) - summed_abundances = np.zeros_like(unique_hashes, dtype=np.uint32) - - # Sum abundances for duplicate hashes - np.add.at(summed_abundances, inverse_indices, combined_abundances) - - # Handle potential overflow - summed_abundances = np.minimum(summed_abundances, np.iinfo(np.uint32).max) - - self.logger.debug("Union operation completed. Total hashes: %d", len(unique_hashes)) - - # Create a new SnipeSig instance - return self.create_from_hashes_abundances( - hashes=unique_hashes, - abundances=summed_abundances, - ksize=self._ksize, - scale=self._scale, - name=f"{self._name}_union_{other._name}", - filename=None, - enable_logging=self.logger.level <= logging.DEBUG - ) - - def _convert_to_sourmash_signature(self): - r""" - Convert the SnipeSig instance to a sourmash.signature.SourmashSignature object. - - Returns: - sourmash.signature.SourmashSignature: A new sourmash.signature.SourmashSignature instance. - """ - self.logger.debug("Converting SnipeSig to sourmash.signature.SourmashSignature.") - - mh = sourmash.minhash.MinHash(n=0, ksize=self._ksize, scaled=self._scale, track_abundance=self._track_abundance) - mh.set_abundances(dict(zip(self._hashes, self._abundances))) - self.sourmash_sig = sourmash.signature.SourmashSignature(mh, name=self._name, filename=self._filename) - self.logger.debug("Conversion to sourmash.signature.SourmashSignature completed.") - - def export(self, path) -> None: - r""" - Export the signature to a file. - - Parameters: - path (str): The path to save the signature to. 
- """ - self._convert_to_sourmash_signature() - with open(str(path), "wb") as fp: - sourmash.signature.save_signatures_to_json([self.sourmash_sig], fp) - - def export_to_string(self): - r""" - Export the signature to a JSON string. - - Returns: - str: JSON string representation of the signature. - """ - self._convert_to_sourmash_signature() - return sourmash.signature.save_signatures_to_json([self.sourmash_sig]).decode('utf-8') - - def intersection_sigs(self, other: 'SnipeSig') -> 'SnipeSig': - r""" - Compute the intersection of the current signature with another signature. - - This method keeps only the hashes that are common to both signatures, and retains the abundances from self. - - **Mathematical Explanation**: - - Let \( A \) and \( B \) be two signatures with sets of hashes \( H_A \) and \( H_B \), - and abundance functions \( a_A(h) \) and \( a_B(h) \), the intersection signature \( C \) has: - - - Hash set: - $$ - H_C = H_A \cap H_B - $$ - - - Abundance function: - $$ - a_C(h) = a_A(h), \quad \text{for } h \in H_C - $$ - - **Parameters**: - - `other (SnipeSig)`: Another `SnipeSig` instance to intersect with. - - **Returns**: - - `SnipeSig`: A new `SnipeSig` instance representing the intersection of the two signatures. - - **Raises**: - - `ValueError`: If `ksize` or `scale` do not match between signatures. - """ - self.__verify_snipe_signature(other) - self.__verify_matching_ksize_scale(other) - - self.logger.debug("Intersecting signatures.") - - # Use numpy's intersect1d function - common_hashes, self_indices, _ = np.intersect1d( - self._hashes, other._hashes, assume_unique=True, return_indices=True - ) - - if common_hashes.size == 0: - self.logger.debug("No common hashes found. 
Returning an empty signature.") - return self.create_from_hashes_abundances( - hashes=np.array([], dtype=np.uint64), - abundances=np.array([], dtype=np.uint32), - ksize=self._ksize, - scale=self._scale, - name=f"{self._name}_intersection_{other._name}", - filename=None, - enable_logging=self.logger.level <= logging.DEBUG - ) - - # Get the abundances from self - common_abundances = self._abundances[self_indices] - - self.logger.debug("Intersection operation completed. Total common hashes: %d", len(common_hashes)) - - # Create a new SnipeSig instance - return self.create_from_hashes_abundances( - hashes=common_hashes, - abundances=common_abundances, - ksize=self._ksize, - scale=self._scale, - name=f"{self._name}_intersection_{other._name}", - filename=None, - enable_logging=self.logger.level <= logging.DEBUG - ) - - def difference_sigs(self, other: 'SnipeSig') -> 'SnipeSig': - r""" - Compute the difference of the current signature with another signature. - - This method removes hashes that are present in the other signature from self, - keeping the abundances from self. - - **Mathematical Explanation**: - - Let \( A \) and \( B \) be two signatures with sets of hashes \( H_A \) and \( H_B \), - and abundance function \( a_A(h) \), the difference signature \( C \) has: - - - Hash set: - $$ - H_C = H_A \setminus H_B - $$ - - - Abundance function: - $$ - a_C(h) = a_A(h), \quad \text{for } h \in H_C - $$ - - **Parameters**: - - `other (SnipeSig)`: Another `SnipeSig` instance to subtract from the current signature. - - **Returns**: - - `SnipeSig`: A new `SnipeSig` instance representing the difference of the two signatures. - - **Raises**: - - `ValueError`: If `ksize` or `scale` do not match between signatures. - - `RuntimeError`: If zero hashes remain after difference. 
- """ - self.__verify_snipe_signature(other) - self.__verify_matching_ksize_scale(other) - - self.logger.debug("Differencing signatures.") - - # Use numpy's setdiff1d function - diff_hashes = np.setdiff1d(self._hashes, other._hashes, assume_unique=True) - - if diff_hashes.size == 0: - _e_msg = f"Difference operation resulted in zero hashes, which is not allowed for {self._name} and {other._name}." - self.logger.warning(_e_msg) - - # Get the indices of the hashes in self - mask = np.isin(self._hashes, diff_hashes, assume_unique=True) - diff_abundances = self._abundances[mask] - - self.logger.debug("Difference operation completed. Remaining hashes: %d", len(diff_hashes)) - - # Create a new SnipeSig instance - return self.create_from_hashes_abundances( - hashes=diff_hashes, - abundances=diff_abundances, - ksize=self._ksize, - scale=self._scale, - name=f"{self._name}_difference_{other._name}", - filename=None, - enable_logging=self.logger.level <= logging.DEBUG - ) - - def symmetric_difference_sigs(self, other: 'SnipeSig') -> 'SnipeSig': - r""" - Compute the symmetric difference of the current signature with another signature. - - This method retains hashes that are unique to each signature, with their respective abundances. - - **Mathematical Explanation**: - - Let \( A \) and \( B \) be two signatures with sets of hashes \( H_A \) and \( H_B \), - and abundance functions \( a_A(h) \) and \( a_B(h) \), the symmetric difference signature \( C \) has: - - - Hash set: - $$ - H_C = (H_A \setminus H_B) \cup (H_B \setminus H_A) - $$ - - - Abundance function: - $$ - a_C(h) = - \begin{cases} - a_A(h), & \text{for } h \in H_A \setminus H_B \\ - a_B(h), & \text{for } h \in H_B \setminus H_A \\ - \end{cases} - $$ - - **Parameters**: - - `other (SnipeSig)`: Another `SnipeSig` instance to compute the symmetric difference with. - - **Returns**: - - `SnipeSig`: A new `SnipeSig` instance representing the symmetric difference of the two signatures. 
- - **Raises**: - - `ValueError`: If `ksize` or `scale` do not match between signatures. - - `RuntimeError`: If zero hashes remain after symmetric difference. - """ - self.__verify_snipe_signature(other) - self.__verify_matching_ksize_scale(other) - - self.logger.debug("Computing symmetric difference of signatures.") - - # Hashes unique to self and other - unique_self_hashes = np.setdiff1d(self._hashes, other._hashes, assume_unique=True) - unique_other_hashes = np.setdiff1d(other._hashes, self._hashes, assume_unique=True) - - # Abundances for unique hashes - mask_self = np.isin(self._hashes, unique_self_hashes, assume_unique=True) - unique_self_abundances = self._abundances[mask_self] - - mask_other = np.isin(other._hashes, unique_other_hashes, assume_unique=True) - unique_other_abundances = other._abundances[mask_other] - - # Handle the case where 'other' does not track abundance - if not other.track_abundance: - self.logger.debug("Other signature does not track abundance. Setting abundances to 1.") - unique_other_abundances = np.ones_like(unique_other_abundances, dtype=np.uint32) - - # Combine hashes and abundances - combined_hashes = np.concatenate((unique_self_hashes, unique_other_hashes)) - combined_abundances = np.concatenate((unique_self_abundances, unique_other_abundances)) - - if combined_hashes.size == 0: - _e_msg = "Symmetric difference operation resulted in zero hashes, which is not allowed." - self.logger.error(_e_msg) - raise RuntimeError(_e_msg) - - # Sort combined hashes and abundances - sorted_indices = np.argsort(combined_hashes) - combined_hashes = combined_hashes[sorted_indices] - combined_abundances = combined_abundances[sorted_indices] - - self.logger.debug("Symmetric difference operation completed. 
Total unique hashes: %d", len(combined_hashes)) - - # Create a new SnipeSig instance - return self.create_from_hashes_abundances( - hashes=combined_hashes, - abundances=combined_abundances, - ksize=self._ksize, - scale=self._scale, - name=f"{self._name}_symmetric_difference_{other._name}", - filename=None, - enable_logging=self.logger.level <= logging.DEBUG - ) - - # Magic methods for union operations - def __add__(self, other: 'SnipeSig') -> 'SnipeSig': - r""" - Implements the + operator. - Includes all unique hashes from both signatures and sums their abundances where hashes overlap, - returning a new signature. - - Returns: - SnipeSig: Union of self and other. - """ - return self.union_sigs(other) - - def __iadd__(self, other: 'SnipeSig') -> 'SnipeSig': - r""" - Implements the += operator. - Includes all unique hashes from both signatures and sums their abundances where hashes overlap, - modifying self in-place. - - Returns: - SnipeSig: Updated self after addition. - """ - union_sig = self.union_sigs(other) - self._update_from_union(union_sig) - return self - - def __or__(self, other: 'SnipeSig') -> 'SnipeSig': - r""" - Implements the | operator. - Includes all unique hashes from both signatures and sums their abundances where hashes overlap, - returning a new signature. - - Returns: - SnipeSig: Union of self and other. - """ - return self.union_sigs(other) - - def __ior__(self, other: 'SnipeSig') -> 'SnipeSig': - r""" - Implements the |= operator. - Includes all unique hashes from both signatures and sums their abundances where hashes overlap, - modifying self in-place. - - Returns: - SnipeSig: Updated self after union. - """ - union_sig = self.union_sigs(other) - self._update_from_union(union_sig) - return self - - def __sub__(self, other: 'SnipeSig') -> 'SnipeSig': - r""" - Implements the - operator. - Removes hashes present in other from self, keeping abundances from self, - returning a new signature. - - Returns: - SnipeSig: Difference of self and other. 
- """ - return self.difference_sigs(other) - - def __isub__(self, other: 'SnipeSig') -> 'SnipeSig': - r""" - Implements the -= operator. - Removes hashes present in other from self, keeping abundances from self, - modifying self in-place. - - Returns: - SnipeSig: Updated self after difference. - - Raises: - RuntimeError: If zero hashes remain after difference. - """ - difference_sig = self.difference_sigs(other) - self._update_from_union(difference_sig) - return self - - def __xor__(self, other: 'SnipeSig') -> 'SnipeSig': - r""" - Implements the ^ operator. - Keeps unique hashes from each signature with their respective abundances, returning a new signature. - - Returns: - SnipeSig: Symmetric difference of self and other. - """ - return self.symmetric_difference_sigs(other) - - def __ixor__(self, other: 'SnipeSig') -> 'SnipeSig': - r""" - Implements the ^= operator. - Keeps unique hashes from each signature with their respective abundances, modifying self in-place. - - Returns: - SnipeSig: Updated self after symmetric difference. - - Raises: - RuntimeError: If zero hashes remain after symmetric difference. - """ - symmetric_diff_sig = self.symmetric_difference_sigs(other) - self._update_from_union(symmetric_diff_sig) - return self - - def __and__(self, other: 'SnipeSig') -> 'SnipeSig': - r""" - Implements the & operator. - Keeps common hashes and retains abundances from self only, returning a new signature. - - Returns: - SnipeSig: Intersection of self and other. - """ - return self.intersection_sigs(other) - - def _update_from_union(self, other: 'SnipeSig'): - r""" - Update self's hashes and abundances from another SnipeSig instance. - - Parameters: - other (SnipeSig): The other SnipeSig instance to update from. 
- """ - self._hashes = other.hashes - self._abundances = other.abundances - self._name = other.name - self._filename = other.filename - self._md5sum = other.md5sum - self._track_abundance = other.track_abundance - # No need to update ksize and scale since they are verified to match - - @classmethod - def create_from_hashes_abundances(cls, hashes: np.ndarray, abundances: np.ndarray, - ksize: int, scale: int, name: str = None, - filename: str = None, enable_logging: bool = False, sig_type: SigType = SigType.SAMPLE) -> 'SnipeSig': - """ - Internal method to create a SnipeSig instance from hashes and abundances. - - Parameters: - hashes (np.ndarray): Array of hash values. - abundances (np.ndarray): Array of abundance values corresponding to the hashes. - ksize (int): K-mer size. - scale (int): Scale value. - name (str): Optional name for the signature. - filename (str): Optional filename for the signature. - sig_type (SigType): Type of the signature. - enable_logging (bool): Flag to enable logging. - - Returns: - SnipeSig: A new SnipeSig instance. - """ - # Create a mock sourmash signature object - mh = sourmash.minhash.MinHash(n=0, ksize=ksize, scaled=scale, track_abundance=True) - mh.set_abundances(dict(zip(hashes, abundances))) - sig = sourmash.signature.SourmashSignature(mh, name=name or "", filename=filename or "") - return cls(sourmash_sig=sig, sig_type=sig_type, enable_logging=enable_logging) - - # Aggregation Operations - @classmethod - def sum_signatures(cls, signatures: List['SnipeSig'], name: str = "summed_signature", - filename: str = None, enable_logging: bool = False) -> 'SnipeSig': - - r""" - Sum multiple SnipeSig instances by including all unique hashes and summing their abundances where hashes overlap. - This method utilizes a heap-based multi-way merge algorithm for enhanced efficiency when handling thousands of signatures. 
- - $$ - \text{Sum}(A_1, A_2, \dots, A_n) = \bigcup_{i=1}^{n} A_i - $$ - - For each hash \( h \), its total abundance is: - $$ - \text{abundance}(h) = \sum_{i=1}^{n} \text{abundance}_i(h) - $$ - - **Mathematical Explanation**: - - - **Union of Signatures**: - The summation of signatures involves creating a union of all unique k-mers (hashes) present across the input signatures. - - - **Total Abundance Calculation**: - For each unique hash \( h \), the total abundance is the sum of its abundances across all signatures where it appears. - - - **Algorithm Efficiency**: - By using a min-heap to perform a multi-way merge of sorted hash arrays, the method ensures that each hash is processed in ascending order without the need to store all hashes in memory simultaneously. - - **Parameters**: - - `signatures (List[SnipeSig])`: List of `SnipeSig` instances to sum. - - `name (str)`: Optional name for the resulting signature. - - `filename (str)`: Optional filename for the resulting signature. - - `enable_logging (bool)`: Flag to enable detailed logging. - - **Returns**: - - `SnipeSig`: A new `SnipeSig` instance representing the sum of the signatures. - - **Raises**: - - `ValueError`: If the signatures list is empty or if `ksize`/`scale` do not match across signatures. - - `RuntimeError`: If an error occurs during the summation process. 
- """ - if not signatures: - raise ValueError("No signatures provided for summation.") - - # Verify that all signatures have the same ksize, scale, and track_abundance - first_sig = signatures[0] - ksize = first_sig.ksize - scale = first_sig.scale - track_abundance = first_sig.track_abundance - - for sig in signatures[1:]: - if sig.ksize != ksize or sig.scale != scale: - raise ValueError("All signatures must have the same ksize and scale.") - if sig.track_abundance != track_abundance: - raise ValueError("All signatures must have the same track_abundance setting.") - - # Initialize iterators for each signature's hashes and abundances - iterators = [] - for sig in signatures: - it = iter(zip(sig.hashes, sig.abundances)) - try: - first_hash, first_abundance = next(it) - iterators.append((first_hash, first_abundance, it)) - except StopIteration: - continue # Skip empty signatures - - if not iterators: - raise ValueError("All provided signatures are empty.") - - # Initialize the heap with the first element from each iterator - heap = [] - for idx, (hash_val, abundance, it) in enumerate(iterators): - heap.append((hash_val, abundance, idx)) - heapq.heapify(heap) - - # Prepare lists to collect the summed hashes and abundances - summed_hashes = [] - summed_abundances = [] - - while heap: - current_hash, current_abundance, idx = heapq.heappop(heap) - # Initialize total abundance for the current_hash - total_abundance = current_abundance - - # Check if the next element in the heap has the same hash - while heap and heap[0][0] == current_hash: - _, abundance, same_idx = heapq.heappop(heap) - total_abundance += abundance - # Push the next element from the same iterator - try: - next_hash, next_abundance = next(iterators[same_idx][2]) - heapq.heappush(heap, (next_hash, next_abundance, same_idx)) - except StopIteration: - pass # No more elements in this iterator - - # Append the summed hash and abundance - summed_hashes.append(current_hash) - 
summed_abundances.append(total_abundance) - - # Push the next element from the current iterator - try: - next_hash, next_abundance = next(iterators[idx][2]) - heapq.heappush(heap, (next_hash, next_abundance, idx)) - except StopIteration: - pass # No more elements in this iterator - - # Convert the results to NumPy arrays for efficient storage and processing - summed_hashes = np.array(summed_hashes, dtype=np.uint64) - summed_abundances = np.array(summed_abundances, dtype=np.uint32) - - # Handle potential overflow by capping at the maximum value of uint32 - summed_abundances = np.minimum(summed_abundances, np.iinfo(np.uint32).max) - - # Create a new SnipeSig instance from the summed hashes and abundances - summed_signature = cls.create_from_hashes_abundances( - hashes=summed_hashes, - abundances=summed_abundances, - ksize=ksize, - scale=scale, - name=name, - filename=filename, - enable_logging=enable_logging - ) - - return summed_signature - - @staticmethod - def get_unique_signatures(signatures: Dict[str, 'SnipeSig']) -> Dict[str, 'SnipeSig']: - """ - Extract unique signatures from a dictionary of SnipeSig instances. - - For each signature, the unique_sig contains only the hashes that do not overlap with any other signature. - - Parameters: - signatures (Dict[str, SnipeSig]): A dictionary mapping signature names to SnipeSig instances. - - Returns: - Dict[str, SnipeSig]: A dictionary mapping signature names to their unique SnipeSig instances. - - Raises: - ValueError: If the input dictionary is empty or if signatures have mismatched ksize/scale. 
- """ - if not signatures: - raise ValueError("The input signatures dictionary is empty.") - - # Extract ksize and scale from the first signature - first_name, first_sig = next(iter(signatures.items())) - ksize = first_sig.ksize - scale = first_sig.scale - - # Verify that all signatures have the same ksize and scale - for name, sig in signatures.items(): - if sig.ksize != ksize or sig.scale != scale: - raise ValueError(f"Signature '{name}' has mismatched ksize or scale.") - - # Aggregate all hashes from all signatures - all_hashes = np.concatenate([sig.hashes for sig in signatures.values()]) - - # Count the occurrences of each hash - unique_hashes, counts = np.unique(all_hashes, return_counts=True) - - # Identify hashes that are unique across all signatures (count == 1) - unique_across_all = unique_hashes[counts == 1] - - # Convert to a set for faster membership testing - unique_set = set(unique_across_all) - - unique_signatures = {} - - for name, sig in signatures.items(): - # Find hashes in the current signature that are unique across all signatures - mask_unique = np.isin(sig.hashes, list(unique_set)) - - # Extract unique hashes and their abundances - unique_hashes_sig = sig.hashes[mask_unique] - unique_abundances_sig = sig.abundances[mask_unique] - - # Create a new SnipeSig instance with the unique hashes and abundances - unique_sig = SnipeSig.create_from_hashes_abundances( - hashes=unique_hashes_sig, - abundances=unique_abundances_sig, - ksize=ksize, - scale=scale, - name=f"{name}_unique", - filename=None, - enable_logging=False, # Set to True if you want logging for the new signatures - sig_type=SigType.SAMPLE # Adjust sig_type as needed - ) - - unique_signatures[name] = unique_sig - - return unique_signatures - - - @classmethod - def common_hashes(cls, signatures: List['SnipeSig'], name: str = "common_hashes_signature", - filename: str = None, enable_logging: bool = False) -> 'SnipeSig': - r""" - Compute the intersection of multiple SnipeSig instances, 
returning a new SnipeSig containing - only the hashes present in all signatures, with abundances set to the minimum abundance across signatures. - - This method uses a heap-based multi-way merge algorithm for efficient computation, - especially when handling a large number of signatures with sorted hashes. - - **Mathematical Explanation**: - - Given signatures \( A_1, A_2, \dots, A_n \) with hash sets \( H_1, H_2, \dots, H_n \), - the intersection signature \( C \) has: - - - Hash set: - $$ - H_C = \bigcap_{i=1}^{n} H_i - $$ - - - Abundance function: - $$ - a_C(h) = \min_{i=1}^{n} a_i(h), \quad \text{for } h \in H_C - $$ - - **Parameters**: - - `signatures (List[SnipeSig])`: List of `SnipeSig` instances to compute the intersection. - - `name (str)`: Optional name for the resulting signature. - - `filename (str)`: Optional filename for the resulting signature. - - `enable_logging (bool)`: Flag to enable detailed logging. - - **Returns**: - - `SnipeSig`: A new `SnipeSig` instance representing the intersection of the signatures. - - **Raises**: - - `ValueError`: If the signatures list is empty or if `ksize`/`scale` do not match across signatures. 
- """ - if not signatures: - raise ValueError("No signatures provided for intersection.") - - # Verify that all signatures have the same ksize and scale - first_sig = signatures[0] - ksize = first_sig.ksize - scale = first_sig.scale - for sig in signatures[1:]: - if sig.ksize != ksize or sig.scale != scale: - raise ValueError("All signatures must have the same ksize and scale.") - - num_signatures = len(signatures) - iterators = [] - for sig in signatures: - it = iter(zip(sig.hashes, sig.abundances)) - try: - first_hash, first_abundance = next(it) - iterators.append((first_hash, first_abundance, it)) - except StopIteration: - # One of the signatures is empty; intersection is empty - return cls.create_from_hashes_abundances( - hashes=np.array([], dtype=np.uint64), - abundances=np.array([], dtype=np.uint32), - ksize=ksize, - scale=scale, - name=name, - filename=filename, - enable_logging=enable_logging - ) - - # Initialize the heap with the first element from each iterator - heap = [] - for idx, (hash_val, abundance, it) in enumerate(iterators): - heap.append((hash_val, abundance, idx)) - heapq.heapify(heap) - - common_hashes = [] - common_abundances = [] - - while heap: - # Pop all entries with the smallest hash - current_hash, current_abundance, idx = heapq.heappop(heap) - same_hash_entries = [(current_hash, current_abundance, idx)] - - # Collect all entries in the heap that have the same current_hash - while heap and heap[0][0] == current_hash: - h, a, i = heapq.heappop(heap) - same_hash_entries.append((h, a, i)) - - if len(same_hash_entries) == num_signatures: - # The current_hash is present in all signatures - # Take the minimum abundance across signatures - min_abundance = min(entry[1] for entry in same_hash_entries) - common_hashes.append(current_hash) - common_abundances.append(min_abundance) - - # Push the next element from each iterator that had the current_hash - for entry in same_hash_entries: - h, a, i = entry - try: - next_hash, next_abundance = 
next(iterators[i][2]) - heapq.heappush(heap, (next_hash, next_abundance, i)) - except StopIteration: - pass # Iterator exhausted - - # Convert the results to NumPy arrays - if not common_hashes: - # No common hashes found - unique_hashes = np.array([], dtype=np.uint64) - unique_abundances = np.array([], dtype=np.uint32) - else: - unique_hashes = np.array(common_hashes, dtype=np.uint64) - unique_abundances = np.array(common_abundances, dtype=np.uint32) - - # Create a new SnipeSig instance from the common hashes and abundances - common_signature = cls.create_from_hashes_abundances( - hashes=unique_hashes, - abundances=unique_abundances, - ksize=ksize, - scale=scale, - name=name, - filename=filename, - enable_logging=enable_logging - ) - - return common_signature - - def copy(self) -> 'SnipeSig': - r""" - Create a copy of the current SnipeSig instance. - - Returns: - SnipeSig: A new instance that is a copy of self. - """ - return SnipeSig(sourmash_sig=self.export_to_string(), sig_type=self.sigtype, enable_logging=self.logger.level <= logging.DEBUG) - - # Implement the __radd__ method to support sum() - def __radd__(self, other: Union[int, 'SnipeSig']) -> 'SnipeSig': - r""" - Implements the right-hand + operator to support sum(). - - Returns: - SnipeSig: Union of self and other. - """ - return self.__radd_sum__(other) - - # Override the __sum__ method - def __radd_sum__(self, other: Union[int, 'SnipeSig']) -> 'SnipeSig': - r""" - Internal helper method to support the sum() function. - - Parameters: - other (int or SnipeSig): The other object to add. If other is 0, return self. - - Returns: - SnipeSig: The result of the addition. - """ - if other == 0: - return self - if not isinstance(other, SnipeSig): - raise TypeError(f"Unsupported operand type(s) for +: 'SnipeSig' and '{type(other).__name__}'") - return self.union_sigs(other) - - def reset_abundance(self, new_abundance: int = 1): - r""" - Reset all abundances to a specified value. 
- - This method sets the abundance of every hash in the signature to the specified `new_abundance` value. - - **Mathematical Explanation**: - - For each hash \( h \) in the signature, the abundance function is updated to: - $$ - a(h) = \text{new\_abundance} - $$ - - **Parameters**: - - `new_abundance (int)`: The new abundance value to set for all hashes. Default is 1. - - **Raises**: - - `ValueError`: If the signature does not track abundance or if `new_abundance` is invalid. - """ - - self._validate_abundance_operation(new_abundance, "reset abundance") - - self._abundances[:] = new_abundance - self.logger.debug("Reset all abundances to %d.", new_abundance) - - def keep_min_abundance(self, min_abundance: int): - r""" - Keep only hashes with abundances greater than or equal to a minimum threshold. - - This method removes hashes whose abundances are less than the specified `min_abundance`. - - **Mathematical Explanation**: - - The updated hash set \( H' \) is: - $$ - H' = \{ h \in H \mid a(h) \geq \text{min\_abundance} \} - $$ - - **Parameters**: - - `min_abundance (int)`: The minimum abundance threshold. - - **Raises**: - - `ValueError`: If the signature does not track abundance or if `min_abundance` is invalid. - """ - self._validate_abundance_operation(min_abundance, "keep minimum abundance") - - mask = self._abundances >= min_abundance - self._apply_mask(mask) - self.logger.debug("Kept hashes with abundance >= %d.", min_abundance) - - def keep_max_abundance(self, max_abundance: int): - r""" - Keep only hashes with abundances less than or equal to a maximum threshold. - - This method removes hashes whose abundances are greater than the specified `max_abundance`. - - **Mathematical Explanation**: - - The updated hash set \( H' \) is: - $$ - H' = \{ h \in H \mid a(h) \leq \text{max\_abundance} \} - $$ - - **Parameters**: - - `max_abundance (int)`: The maximum abundance threshold. 
- - **Raises**: - - `ValueError`: If the signature does not track abundance or if `max_abundance` is invalid. - """ - self._validate_abundance_operation(max_abundance, "keep maximum abundance") - - mask = self._abundances <= max_abundance - self._apply_mask(mask) - self.logger.debug("Kept hashes with abundance <= %d.", max_abundance) - - def trim_below_median(self): - r""" - Trim hashes with abundances below the median abundance. - - This method removes all hashes whose abundances are less than the median abundance of the signature. - - **Mathematical Explanation**: - - Let \\( m \\) be the median of \\( \\{ a(h) \mid h \in H \\} \\). - The updated hash set \\( H' \\) is: - - $$ - H' = \\{ h \in H \mid a(h) \geq m \\} - $$ - - **Raises**: - - `ValueError`: If the signature does not track abundance. - """ - - self._validate_abundance_operation(None, "trim below median") - - if len(self._abundances) == 0: - self.logger.debug("No hashes to trim based on median abundance.") - return - - median = np.median(self._abundances) - mask = self._abundances >= median - self._apply_mask(mask) - self.logger.debug("Trimmed hashes with abundance below median (%f).", median) - - def count_singletons(self) -> int: - r""" - Return the number of hashes with abundance equal to 1. - - Returns: - int: Number of singletons. - - Raises: - ValueError: If the signature does not track abundance. - """ - self._validate_abundance_operation(None, "count singletons") - - count = np.sum(self._abundances == 1) - self.logger.debug("Number of singletons (abundance == 1): %d", count) - return int(count) - - def trim_singletons(self): - r""" - Remove hashes with abundance equal to 1. - - This method removes all hashes that are singletons (abundance equals 1). - - **Mathematical Explanation**: - - The updated hash set \( H' \) is: - $$ - H' = \{ h \in H \mid a(h) \neq 1 \} - $$ - - **Raises**: - - `ValueError`: If the signature does not track abundance. 
- """ - self._validate_abundance_operation(None, "trim singletons") - - mask = self._abundances != 1 - self._apply_mask(mask) - self.logger.debug("Trimmed hashes with abundance equal to 1.") - - # New Properties Implemented as per Request - - @property - def total_abundance(self) -> int: - r""" - Return the total abundance (sum of all abundances). - - Returns: - int: Total abundance. - """ - self._validate_abundance_operation(None, "calculate total abundance") - - total = int(np.sum(self._abundances)) - self.logger.debug("Total abundance: %d", total) - return total - - @property - def mean_abundance(self) -> float: - r""" - Return the mean (average) abundance. - - Returns: - float: Mean abundance. - """ - self._validate_abundance_operation(None, "calculate mean abundance") - - if len(self._abundances) == 0: - self.logger.debug("No abundances to calculate mean.") - return 0.0 - - mean = float(np.mean(self._abundances)) # Changed to float - self.logger.debug("Mean abundance: %f", mean) - return mean - - @property - def get_sample_stats(self) -> dict: - r""" - Retrieve statistical information about the signature. - - This property computes and returns a dictionary containing various statistics of the signature, such as total abundance, mean and median abundances, number of singletons, and total number of hashes. - - **Returns**: - - `dict`: A dictionary containing sample statistics: - - `total_abundance`: Sum of abundances. - - `mean_abundance`: Mean abundance. - - `median_abundance`: Median abundance. - - `num_singletons`: Number of hashes with abundance equal to 1. - - `num_hashes`: Total number of hashes. - - `ksize`: K-mer size. - - `scale`: Scale value. - - `name`: Name of the signature. - - `filename`: Filename of the signature. 
- """ - - # if self.sigtype != SigType.SAMPLE then don't return abundance stats - - stats = { - "num_hashes": len(self._hashes), - "ksize": self._ksize, - "scale": self._scale, - "name": self._name, - "filename": self._filename - } - - if self.sigtype != SigType.SAMPLE: - stats["total_abundance"] = None - stats["mean_abundance"] = None - stats["median_abundance"] = None - stats["num_singletons"] = None - else: - stats["total_abundance"] = self.total_abundance - stats["mean_abundance"] = self.mean_abundance - stats["median_abundance"] = self.median_abundance - stats["num_singletons"] = self.count_singletons() - - return stats - - @property - def median_abundance(self) -> float: - r""" - Return the median abundance. - - Returns: - float: Median abundance. - - Raises: - ValueError: If the signature does not track abundance. - """ - self._validate_abundance_operation(None, "calculate median abundance") - - if len(self._abundances) == 0: - self.logger.debug("No abundances to calculate median.") - return 0.0 - - median = float(np.median(self._abundances)) # Changed to float - self.logger.debug("Median abundance: %f", median) - return median diff --git a/tests/api/test_api.py b/tests/api/test_api.py index e77a452..fa46bee 100644 --- a/tests/api/test_api.py +++ b/tests/api/test_api.py @@ -7,7 +7,7 @@ import numpy as np import sourmash -from snipe.api import SnipeSig +from snipe.api.snipe_sig import SnipeSig from snipe.api.enums import SigType def sample_signature_json(): diff --git a/tests/api/test_reference_qc.py b/tests/api/test_reference_qc.py index e48c294..a1cc77a 100644 --- a/tests/api/test_reference_qc.py +++ b/tests/api/test_reference_qc.py @@ -7,7 +7,7 @@ import numpy as np import sourmash -from snipe.api import SnipeSig +from snipe.api.snipe_sig import SnipeSig from snipe.api.enums import SigType from snipe.api.reference_QC import ReferenceQC @@ -918,6 +918,303 @@ def test_calculate_sex_chrs_metrics(self): # Y-Coverage = (0/3) / (5/6) = 0 / 0.8333 = 0.0 
expected_ycoverage = 0.0 self.assertAlmostEqual(metrics["Y-Coverage"], expected_ycoverage, places=4) + + def test_nonref_consume_from_vars_basic(self): + """ + Test that nonref_consume_from_vars correctly assigns non-reference k-mers to variables in vars_order. + """ + # Create a sample signature with some non-reference k-mers + # Reference signature has [10,20,30,40,50,60] + # Sample signature has [10,20,30,40,50,70,80,90] + sample_sig_nonref = self.create_test_signature( + hashes=[10, 20, 30, 40, 50, 70, 80, 90], + abundances=[1, 2, 3, 4, 5, 6, 7, 8], + name="test_sample_nonref", + sig_type=SigType.SAMPLE + ) + # Reference signature as per setUp: [10,20,30,40,50,60] + # Non-reference k-mers: [70,80,90] + + # Define variables with overlapping k-mers + vars_signatures = { + "var_A": self.create_test_signature( + hashes=[70, 80], + abundances=[6, 7], + name="var_A", + sig_type=SigType.SAMPLE + ), + "var_B": self.create_test_signature( + hashes=[80, 90], + abundances=[7, 8], + name="var_B", + sig_type=SigType.SAMPLE + ) + } + vars_order = ["var_A", "var_B"] + + qc = ReferenceQC( + sample_sig=sample_sig_nonref, + reference_sig=self.reference_sig, + amplicon_sig=None, + enable_logging=False + ) + + nonref_stats = qc.nonref_consume_from_vars(vars=vars_signatures, vars_order=vars_order) + + # Expected: + # var_A consumes [70,80]: total abundance = 6 + 7 = 13 + # Coverage index for var_A: 2 / 3 ≈ 0.6667 + # var_B consumes [90]: total abundance = 8 + # Coverage index for var_B: 1 / 3 ≈ 0.3333 + # non-var: 0 + expected_stats = { + "var_A non-genomic total k-mer abundance": 13, + "var_A non-genomic coverage index": 2 / 3, + "var_B non-genomic total k-mer abundance": 8, + "var_B non-genomic coverage index": 1 / 3, + "non-var non-genomic total k-mer abundance": 0, + "non-var non-genomic coverage index": 0 / 3 + } + + # Verify the stats + for key, value in expected_stats.items(): + self.assertAlmostEqual(nonref_stats.get(key, None), value, places=4, msg=f"Mismatch in {key}") + 
+ def test_nonref_consume_from_vars_overlapping_vars(self): + """ + Test that nonref_consume_from_vars handles overlapping variables correctly, consuming k-mers only once. + """ + # Create a sample signature with some non-reference k-mers + sample_sig_nonref = self.create_test_signature( + hashes=[70, 80, 90, 100], + abundances=[6, 7, 8, 9], + name="test_sample_nonref_overlap", + sig_type=SigType.SAMPLE + ) + # Non-reference k-mers: [70,80,90,100] + + # Define variables with overlapping k-mers + vars_signatures = { + "var_A": self.create_test_signature( + hashes=[70, 80], + abundances=[6, 7], + name="var_A", + sig_type=SigType.SAMPLE + ), + "var_B": self.create_test_signature( + hashes=[80, 90], + abundances=[7, 8], + name="var_B", + sig_type=SigType.SAMPLE + ), + "var_C": self.create_test_signature( + hashes=[90, 100], + abundances=[8, 9], + name="var_C", + sig_type=SigType.SAMPLE + ) + } + vars_order = ["var_A", "var_B", "var_C"] + + qc = ReferenceQC( + sample_sig=sample_sig_nonref, + reference_sig=self.reference_sig, + amplicon_sig=None, + enable_logging=False + ) + + nonref_stats = qc.nonref_consume_from_vars(vars=vars_signatures, vars_order=vars_order) + + # Expected: + # var_A consumes [70,80]: total abundance = 6 + 7 = 13 + # Coverage index for var_A: 2 / 4 = 0.5 + # var_B consumes [90]: total abundance = 8 + # Coverage index for var_B: 1 / 4 = 0.25 + # var_C consumes [100]: total abundance = 9 + # Coverage index for var_C: 1 / 4 = 0.25 + # non-var: 0 + expected_stats = { + "var_A non-genomic total k-mer abundance": 13, + "var_A non-genomic coverage index": 2 / 4, + "var_B non-genomic total k-mer abundance": 8, + "var_B non-genomic coverage index": 1 / 4, + "var_C non-genomic total k-mer abundance": 9, + "var_C non-genomic coverage index": 1 / 4, + "non-var non-genomic total k-mer abundance": 0, + "non-var non-genomic coverage index": 0 / 4 + } + + # Verify the stats + for key, value in expected_stats.items(): + self.assertAlmostEqual(nonref_stats.get(key, 
None), value, places=4, msg=f"Mismatch in {key}") + + def test_nonref_consume_from_vars_empty_vars(self): + """ + Test that nonref_consume_from_vars handles empty vars correctly. + """ + # Create a sample signature with some non-reference k-mers + sample_sig_nonref = self.create_test_signature( + hashes=[70, 80, 90], + abundances=[6, 7, 8], + name="test_sample_nonref_empty_vars", + sig_type=SigType.SAMPLE + ) + # Non-reference k-mers: [70,80,90] + + # Define empty vars + vars_signatures = {} + vars_order = [] + + qc = ReferenceQC( + sample_sig=sample_sig_nonref, + reference_sig=self.reference_sig, + amplicon_sig=None, + enable_logging=False + ) + + nonref_stats = qc.nonref_consume_from_vars(vars=vars_signatures, vars_order=vars_order) + + # Expected: + # non-var consumes all k-mers + expected_stats = { + "non-var non-genomic total k-mer abundance": 21, # 6 + 7 + 8 + "non-var non-genomic coverage index": 3 / 3 # 3 k-mers + } + + # Verify the stats + self.assertEqual(len(nonref_stats), 2) + self.assertAlmostEqual(nonref_stats.get("non-var non-genomic total k-mer abundance"), 21, places=4) + self.assertAlmostEqual(nonref_stats.get("non-var non-genomic coverage index"), 1.0, places=4) + + def test_nonref_consume_from_vars_vars_order_not_matching_vars(self): + """ + Test that nonref_consume_from_vars raises ValueError when vars_order contains variables not in vars. 
+ """ + # Create a sample signature with some non-reference k-mers + sample_sig_nonref = self.create_test_signature( + hashes=[70, 80, 90], + abundances=[6, 7, 8], + name="test_sample_nonref_invalid_order", + sig_type=SigType.SAMPLE + ) + # Non-reference k-mers: [70,80,90] + + # Define vars with one variable + vars_signatures = { + "var_A": self.create_test_signature( + hashes=[70], + abundances=[6], + name="var_A", + sig_type=SigType.SAMPLE + ) + } + # vars_order includes a variable not in vars_signatures + vars_order = ["var_A", "var_B"] + + qc = ReferenceQC( + sample_sig=sample_sig_nonref, + reference_sig=self.reference_sig, + amplicon_sig=None, + enable_logging=False + ) + + with self.assertRaises(ValueError): + qc.nonref_consume_from_vars(vars=vars_signatures, vars_order=vars_order) + + def test_nonref_consume_from_vars_no_nonref_kmers(self): + """ + Test that nonref_consume_from_vars returns empty dict when there are no non-reference k-mers. + """ + # Create a sample signature identical to reference signature + sample_sig_identical = self.create_test_signature( + hashes=[10,20,30,40,50,60], + abundances=[1,2,3,4,5,6], + name="test_sample_identical", + sig_type=SigType.SAMPLE + ) + # Non-reference k-mers: none + + # Define variables + vars_signatures = { + "var_A": self.create_test_signature( + hashes=[70], + abundances=[7], + name="var_A", + sig_type=SigType.SAMPLE + ) + } + vars_order = ["var_A"] + + qc = ReferenceQC( + sample_sig=sample_sig_identical, + reference_sig=self.reference_sig, + amplicon_sig=None, + enable_logging=False + ) + + nonref_stats = qc.nonref_consume_from_vars(vars=vars_signatures, vars_order=vars_order) + + # Expected: empty dict since no non-reference k-mers + self.assertEqual(nonref_stats, {}) + + def test_nonref_consume_from_vars_all_kmers_consumed(self): + """ + Test that nonref_consume_from_vars correctly reports when all non-reference k-mers are consumed by variables. 
+ """ + # Create a sample signature with some non-reference k-mers + sample_sig_nonref = self.create_test_signature( + hashes=[70, 80, 90], + abundances=[6, 7, 8], + name="test_sample_nonref_all_consumed", + sig_type=SigType.SAMPLE + ) + # Non-reference k-mers: [70,80,90] + + # Define variables that consume all non-reference k-mers + vars_signatures = { + "var_A": self.create_test_signature( + hashes=[70, 80], + abundances=[6, 7], + name="var_A", + sig_type=SigType.SAMPLE + ), + "var_B": self.create_test_signature( + hashes=[90], + abundances=[8], + name="var_B", + sig_type=SigType.SAMPLE + ) + } + vars_order = ["var_A", "var_B"] + + qc = ReferenceQC( + sample_sig=sample_sig_nonref, + reference_sig=self.reference_sig, + amplicon_sig=None, + enable_logging=False + ) + + nonref_stats = qc.nonref_consume_from_vars(vars=vars_signatures, vars_order=vars_order) + + # Expected: + # var_A consumes [70,80]: total abundance = 6 + 7 = 13 + # Coverage index for var_A: 2 / 3 ≈ 0.6667 + # var_B consumes [90]: total abundance = 8 + # Coverage index for var_B: 1 / 3 ≈ 0.3333 + # non-var: 0 + expected_stats = { + "var_A non-genomic total k-mer abundance": 13, + "var_A non-genomic coverage index": 2 / 3, + "var_B non-genomic total k-mer abundance": 8, + "var_B non-genomic coverage index": 1 / 3, + "non-var non-genomic total k-mer abundance": 0, + "non-var non-genomic coverage index": 0 / 3 + } + + # Verify the stats + for key, value in expected_stats.items(): + self.assertAlmostEqual(nonref_stats.get(key, None), value, places=4, msg=f"Mismatch in {key}") + From 5e4d6dd8ae8446903e89f84382b7fee299830c2b Mon Sep 17 00:00:00 2001 From: Mohamed Abuelanin Date: Mon, 14 Oct 2024 17:12:46 -0700 Subject: [PATCH 6/7] adding var options --- .pylintrc | 1 + .vscode/settings.json | 4 + docs/docs/Sketch.md | 3 + docs/docs/SnipeSig.md | 2 +- pyproject.toml | 4 + src/snipe/api/reference_QC.py | 465 +++++++++++++++++++++++++++++++++- src/snipe/api/sketch.py | 5 +- src/snipe/api/snipe_sig.py | 107 
++++---- src/snipe/cli/main.py | 74 +++++- 9 files changed, 585 insertions(+), 80 deletions(-) create mode 100644 .pylintrc create mode 100644 docs/docs/Sketch.md diff --git a/.pylintrc b/.pylintrc new file mode 100644 index 0000000..03fb003 --- /dev/null +++ b/.pylintrc @@ -0,0 +1 @@ +disable=line-too-long \ No newline at end of file diff --git a/.vscode/settings.json b/.vscode/settings.json index af7446d..df01f63 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -2,4 +2,8 @@ "python.analysis.extraPaths": [ "./src" ] + // # show line numbers in jupyter notebooks + // "jupyter.lineNumbers": "on", + // # show line numbers in jup + } \ No newline at end of file diff --git a/docs/docs/Sketch.md b/docs/docs/Sketch.md new file mode 100644 index 0000000..effce2c --- /dev/null +++ b/docs/docs/Sketch.md @@ -0,0 +1,3 @@ +# Python API Documentation + +::: snipe.api.sketch \ No newline at end of file diff --git a/docs/docs/SnipeSig.md b/docs/docs/SnipeSig.md index d22ce15..6525f5a 100644 --- a/docs/docs/SnipeSig.md +++ b/docs/docs/SnipeSig.md @@ -1,3 +1,3 @@ # Python API Documentation -::: snipe.api \ No newline at end of file +::: snipe.api.snipe_sig \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 2ad326f..c0281f0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -29,6 +29,10 @@ dependencies = [ "numpy", "pyfastx", "pathos", + "pandas", + "tqdm", + "lzstring", + "rapidfuzz", ] [project.optional-dependencies] diff --git a/src/snipe/api/reference_QC.py b/src/snipe/api/reference_QC.py index 3a7cc78..be43b27 100644 --- a/src/snipe/api/reference_QC.py +++ b/src/snipe/api/reference_QC.py @@ -5,8 +5,17 @@ import numpy as np from scipy.optimize import OptimizeWarning, curve_fit -from snipe.api import SnipeSig +from snipe.api.snipe_sig import SnipeSig from snipe.api.enums import SigType +import os +import requests +from tqdm import tqdm +import cgi +from urllib.parse import urlparse +from typing import Optional +import sourmash + +# 
pylint disable C0301 class ReferenceQC: @@ -245,6 +254,11 @@ def __init__(self, *, **kwargs): # Initialize logger self.logger = logging.getLogger(self.__class__.__name__) + + # Initialize split cache + self._split_cache: Dict[int, List[SnipeSig]] = {} + self.logger.debug("Initialized split cache.") + if enable_logging: self.logger.setLevel(logging.DEBUG) @@ -318,6 +332,8 @@ def __init__(self, *, self.advanced_stats: Dict[str, Any] = {} self.chrs_stats: Dict[str, Dict[str, Any]] = {} self.sex_stats: Dict[str, Any] = {} + self.predicted_error_contamination_index: Dict[str, Any] = {} + self.vars_nonref_stats: Dict[str, Any] = {} self.predicted_assay_type: str = "" # Set grey zone thresholds @@ -367,10 +383,10 @@ def _calculate_stats(self): # Get stats (call get_sample_stats only once) # Log hashes and abundances for both sample and reference - self.logger.debug("Sample hashes: %s", self.sample_sig.hashes) - self.logger.debug("Sample abundances: %s", self.sample_sig.abundances) - self.logger.debug("Reference hashes: %s", self.reference_sig.hashes) - self.logger.debug("Reference abundances: %s", self.reference_sig.abundances) + # self.logger.debug("Sample hashes: %s", self.sample_sig.hashes) + # self.logger.debug("Sample abundances: %s", self.sample_sig.abundances) + # self.logger.debug("Reference hashes: %s", self.reference_sig.hashes) + # self.logger.debug("Reference abundances: %s", self.reference_sig.abundances) sample_genome_stats = sample_genome.get_sample_stats @@ -411,7 +427,7 @@ def _calculate_stats(self): ), } - # Relative metrics + # ============= RELATIVE STATS ============= self.amplicon_stats["Relative total abundance"] = ( self.amplicon_stats["Amplicon k-mers total abundance"] / self.genome_stats["Genomic k-mers total abundance"] if self.genome_stats["Genomic k-mers total abundance"] > 0 else 0 @@ -421,7 +437,6 @@ def _calculate_stats(self): if self.genome_stats["Genome coverage index"] > 0 else 0 ) - # Predicted assay type relative_total_abundance = 
self.amplicon_stats["Relative total abundance"] if relative_total_abundance <= self.relative_total_abundance_grey_zone[0]: self.predicted_assay_type = "WGS" @@ -432,7 +447,23 @@ def _calculate_stats(self): distance_to_wgs = abs(relative_total_abundance - self.relative_total_abundance_grey_zone[0]) distance_to_wxs = abs(relative_total_abundance - self.relative_total_abundance_grey_zone[1]) self.predicted_assay_type = "WGS" if distance_to_wgs < distance_to_wxs else "WXS" + + self.logger.debug("Predicted assay type: %s", self.predicted_assay_type) + + self.logger.debug("Calculuating error and contamination indices.") + sample_nonref = self.sample_sig - self.reference_sig + sample_nonref_singletons = sample_nonref.count_singletons() + sample_nonref_non_singletons = sample_nonref.total_abundance - sample_nonref_singletons + sample_total_abundance = self.sample_sig.total_abundance + + predicted_error_index = sample_nonref_singletons / sample_total_abundance + predicted_contamination_index = sample_nonref_non_singletons / sample_total_abundance + + # predict error and contamination index + self.predicted_error_contamination_index["Predicted contamination index"] = predicted_contamination_index + self.predicted_error_contamination_index["Sequencing errors index"] = predicted_error_index + def get_aggregated_stats(self, include_advanced: bool = False) -> Dict[str, Any]: r""" @@ -469,11 +500,17 @@ def get_aggregated_stats(self, include_advanced: bool = False) -> Dict[str, Any] if self.sex_stats: aggregated_stats.update(self.sex_stats) + + if self.vars_nonref_stats: + aggregated_stats.update(self.vars_nonref_stats) # Include advanced_stats if requested if include_advanced: self._calculate_advanced_stats() aggregated_stats.update(self.advanced_stats) + + if self.predicted_error_contamination_index: + aggregated_stats.update(self.predicted_error_contamination_index) return aggregated_stats @@ -646,7 +683,15 @@ def split_sig_randomly(self, n: int) -> List[SnipeSig]: 
print(f"Signature part {idx}: {sig}") ``` """ - self.logger.debug("Splitting sample signature into %d random parts.", n) + self.logger.debug("Attempting to split sample signature into %d random parts.", n) + + # Check if the split for this n is already cached + if n in self._split_cache: + self.logger.debug("Using cached split signatures for n=%d.", n) + # Return deep copies to prevent external modifications + return [sig.copy() for sig in self._split_cache[n]] + + self.logger.debug("No cached splits found for n=%d. Proceeding to split.", n) # Get k-mers and abundances hash_to_abund = dict(zip(self.sample_sig.hashes, self.sample_sig.abundances)) random_split_sigs = self.distribute_kmers_random(hash_to_abund, n) @@ -662,6 +707,11 @@ def split_sig_randomly(self, n: int) -> List[SnipeSig]: ) for i, kmer_dict in enumerate(random_split_sigs) ] + + # Cache the split signatures + self._split_cache[n] = split_sigs + self.logger.debug("Cached split signatures for n=%d.", n) + return split_sigs @staticmethod @@ -767,15 +817,19 @@ def calculate_coverage_vs_depth(self, n: int = 30) -> List[Dict[str, Any]]: roi_reference_sig = self.reference_sig self.logger.debug("Using reference genome signature as ROI reference.") - # Split the sample signature into n random parts + # Split the sample signature into n random parts (cached if available) split_sigs = self.split_sig_randomly(n) coverage_depth_data = [] + if not split_sigs: + self.logger.error("No split signatures available. Cannot calculate coverage vs depth.") + return coverage_depth_data + cumulative_snipe_sig = split_sigs[0].copy() cumulative_total_abundance = cumulative_snipe_sig.total_abundance - #! 
force conversion to GENOME + # Force conversion to GENOME roi_reference_sig.sigtype = SigType.GENOME # Compute initial coverage index @@ -785,7 +839,7 @@ def calculate_coverage_vs_depth(self, n: int = 30) -> List[Dict[str, Any]]: enable_logging=self.enable_logging ) cumulative_stats = cumulative_qc.get_aggregated_stats() - cumulative_coverage_index = cumulative_stats["Genome coverage index"] + cumulative_coverage_index = cumulative_stats.get("Genome coverage index", 0.0) coverage_depth_data.append({ "cumulative_parts": 1, @@ -793,6 +847,8 @@ def calculate_coverage_vs_depth(self, n: int = 30) -> List[Dict[str, Any]]: "cumulative_coverage_index": cumulative_coverage_index, }) + self.logger.debug("Added initial coverage depth data for part 1.") + # Iterate over the rest of the parts for i in range(1, n): current_part = split_sigs[i] @@ -808,7 +864,7 @@ def calculate_coverage_vs_depth(self, n: int = 30) -> List[Dict[str, Any]]: enable_logging=self.enable_logging ) cumulative_stats = cumulative_qc.get_aggregated_stats() - cumulative_coverage_index = cumulative_stats["Genome coverage index"] + cumulative_coverage_index = cumulative_stats.get("Genome coverage index", 0.0) coverage_depth_data.append({ "cumulative_parts": i + 1, @@ -816,6 +872,8 @@ def calculate_coverage_vs_depth(self, n: int = 30) -> List[Dict[str, Any]]: "cumulative_coverage_index": cumulative_coverage_index, }) + self.logger.debug("Added coverage depth data for part %d.", i + 1) + self.logger.debug("Coverage vs depth calculation completed.") return coverage_depth_data @@ -1034,17 +1092,22 @@ def calculate_chromosome_metrics(self, chr_to_sig: Dict[str, SnipeSig]) -> Dict[ # Implementation of the method # let's make sure all chromosome sigs are unique + self.logger.debug("Computing specific chromosome hashes for %s.", ','.join(chr_to_sig.keys())) + self.logger.debug(f"\t-All hashes for chromosomes before getting unique sigs {len(SnipeSig.sum_signatures(list(chr_to_sig.values())))}") specific_chr_to_sig = 
SnipeSig.get_unique_signatures(chr_to_sig) + self.logger.debug(f"\t-All hashes for chromosomes after getting unique sigs {len(SnipeSig.sum_signatures(list(specific_chr_to_sig.values())))}") # calculate mean abundance for each chromosome and loaded sample sig chr_to_mean_abundance = {} self.logger.debug("Calculating mean abundance for each chromosome.") for chr_name, chr_sig in specific_chr_to_sig.items(): + self.logger.debug("Intersecting %s (%d) with %s (%d)", self.sample_sig.name, len(self.sample_sig), chr_name, len(chr_sig)) chr_sample_sig = self.sample_sig & chr_sig chr_stats = chr_sample_sig.get_sample_stats chr_to_mean_abundance[chr_name] = chr_stats["mean_abundance"] self.logger.debug("\t-Mean abundance for %s: %f", chr_name, chr_stats["mean_abundance"]) - + + self.chrs_stats.update(chr_to_mean_abundance) # chr_to_mean_abundance but without any chr with partial name sex autosomal_chr_to_mean_abundance = {} @@ -1266,5 +1329,379 @@ def calculate_sex_chrs_metrics(self, genome_and_chr_to_sig: Dict[str, SnipeSig]) self.logger.debug("Calculated Y-Coverage: %.4f", ycoverage) self.sex_stats.update({"Y-Coverage": ycoverage}) + + return self.sex_stats + + + + def nonref_consume_from_vars(self, *, vars: Dict[str, SnipeSig], vars_order: List[str], **kwargs) -> Dict[str, float]: + r""" + Consume and analyze non-reference k-mers from provided variable signatures. + + This method processes non-reference k-mers in the sample signature by intersecting them with a set of + variable-specific `SnipeSig` instances. It calculates coverage and total abundance metrics for each + variable in a specified order, ensuring that each non-reference k-mer is accounted for without overlap + between variables. The method updates internal statistics that reflect the distribution of non-reference + k-mers across the provided variables. + + **Process Overview**: + + 1. **Validation**: + - Verifies that all variable names specified in `vars_order` are present in the `vars` dictionary. 
+ - Raises a `ValueError` if any variable in `vars_order` is missing from `vars`. + + 2. **Non-Reference K-mer Extraction**: + - Computes the set of non-reference non-singleton k-mers by subtracting the reference signature from the sample signature and trimming singletons. + - If no non-reference k-mers are found, the method logs a warning and returns an empty dictionary. + + 3. **Variable-wise Consumption**: + - Iterates over each variable name in `vars_order`. + - For each variable: + - Intersects the remaining non-reference k-mers with the variable-specific signature. + - Calculates the total abundance and coverage index for the intersected k-mers. + - Updates the `vars_nonref_stats` dictionary with the computed metrics. + - Removes the consumed k-mers from the remaining non-reference set to prevent overlap. + + 4. **Final State Logging**: + - Logs the final size and total abundance of the remaining non-reference k-mers after consumption. + + **Parameters**: + + - `vars` (`Dict[str, SnipeSig]`): + A dictionary mapping variable names to their corresponding `SnipeSig` instances. Each `SnipeSig` + represents a set of k-mers associated with a specific non-reference category or variable. + + - `vars_order` (`List[str]`): + A list specifying the order in which variables should be processed. The order determines the priority + of consumption, ensuring that earlier variables in the list have their k-mers accounted for before + later ones. + + - `**kwargs`: + Additional keyword arguments. Reserved for future extensions and should not be used in the current context. + + **Returns**: + + - `Dict[str, float]`: + A dictionary with two entries per variable name in `vars_order`, each key prefixed by that variable name: + - `"non-genomic total k-mer abundance"` (`float`): + The sum of abundances of non-reference k-mers associated with the variable.
+ - `"non-genomic coverage index"` (`float`): + The ratio of unique non-reference k-mers associated with the variable to the total number + of non-reference k-mers in the sample before consumption. + + Example Output: + ```python + { + "variable_A non-genomic total k-mer abundance": 1500.0, + "variable_A non-genomic coverage index": 0.20, + "variable_B non-genomic total k-mer abundance": 3500.0, + "variable_B non-genomic coverage index": 0.70, + "non-var non-genomic total k-mer abundance": 218.0, + "non-var non-genomic coverage index": 0.10 + } + ``` + + **Raises**: + + - `ValueError`: + - If any variable specified in `vars_order` is not present in the `vars` dictionary. + - This ensures that all variables intended for consumption are available for processing. + + **Usage Example**: + + ```python + # Assume `variables_signatures` is a dictionary of variable-specific SnipeSig instances + variables_signatures = { + "GTDB": sig_GTDB, + "VIRALDB": sig_VIRALDB, + "contaminant_X": sig_contaminant_x + } + + # Define the order in which variables should be processed + processing_order = ["GTDB", "VIRALDB", "contaminant_X"] + + # Consume non-reference k-mers and retrieve statistics + nonref_stats = qc.nonref_consume_from_vars(vars=variables_signatures, vars_order=processing_order) + + print(nonref_stats) + # Output Example: + # { + # "GTDB non-genomic total k-mer abundance": 1500.0, + # "GTDB non-genomic coverage index": 0.2, + # "VIRALDB non-genomic total k-mer abundance": 3500.0, + # "VIRALDB non-genomic coverage index": 0.70, + # "contaminant_X non-genomic total k-mer abundance": 0.0, + # "contaminant_X non-genomic coverage index": 0.0, + # "non-var non-genomic total k-mer abundance": 100.0, + # "non-var non-genomic coverage index": 0.1 + # } + ``` + + **Notes**: + + - **Variable Processing Order**: + The `vars_order` list determines the sequence in which variables are processed.
This order is crucial + when there is potential overlap in k-mers between variables, as earlier variables in the list have + higher priority in consuming shared k-mers. + + - **Non-Reference K-mers Definition**: + Non-reference k-mers are defined as those present in the sample signature but absent in the reference + signature. This method focuses on characterizing these unique k-mers relative to provided variables. + """ + + # check the all vars in vars_order are in vars + if not all([var in vars for var in vars_order]): + # report dict keys, and the vars order + self.logger.debug("Provided vars_order: %s, and vars keys: %s", vars_order, list(vars.keys())) + self.logger.error("All variables in vars_order must be present in vars.") + raise ValueError("All variables in vars_order must be present in vars.") + + self.logger.debug("Consuming non-reference k-mers from provided variables.") + self.logger.debug("\t-Current size of the sample signature: %d hashes.", len(self.sample_sig)) + + sample_nonref = self.sample_sig - self.reference_sig + + sample_nonref.trim_singletons() + + sample_nonref_unique_hashes = len(sample_nonref) + + self.logger.debug("\t-Size of non-reference k-mers in the sample signature: %d hashes.", len(sample_nonref)) + if len(sample_nonref) == 0: + self.logger.warning("No non-reference k-mers found in the sample signature.") + return {} + + # intersect and report coverage and depth, then subtract from sample_nonref so sum will be 100% + for var_name in vars_order: + sample_nonref_var: SnipeSig = sample_nonref & vars[var_name] + sample_nonref_var_total_abundance = sample_nonref_var.total_abundance + sample_nonref_var_unique_hashes = len(sample_nonref_var) + sample_nonref_var_coverage_index = sample_nonref_var_unique_hashes / sample_nonref_unique_hashes + self.vars_nonref_stats.update({ + f"{var_name} non-genomic total k-mer abundance": sample_nonref_var_total_abundance, + f"{var_name} non-genomic coverage index": sample_nonref_var_coverage_index + }) 
+ + self.logger.debug("\t-Consuming non-reference k-mers from variable '%s'.", var_name) + sample_nonref -= sample_nonref_var + self.logger.debug("\t-Size of remaining non-reference k-mers in the sample signature: %d hashes.", len(sample_nonref)) + + self.vars_nonref_stats["non-var non-genomic total k-mer abundance"] = sample_nonref.total_abundance + self.vars_nonref_stats["non-var non-genomic coverage index"] = len(sample_nonref) / sample_nonref_unique_hashes if sample_nonref_unique_hashes > 0 else 0.0 + + self.logger.debug( + "After consuming all vars from the non reference k-mers, the size of the sample signature is: %d hashes, " + "with total abundance of %s.", + len(sample_nonref), sample_nonref.total_abundance + ) + + return self.vars_nonref_stats + + def load_genome_sig_to_dict(self, *, zip_file_path: str, **kwargs) -> Dict[str, 'SnipeSig']: + """ + Load a genome signature into a dictionary of SnipeSig instances. + + Parameters: + zip_file_path (str): Path to the zip file containing the genome signatures. + **kwargs: Additional keyword arguments to pass to the SnipeSig constructor. + + Returns: + Dict[str, SnipeSig]: A dictionary mapping genome names to SnipeSig instances. 
+ """ - return self.sex_stats \ No newline at end of file + genome_chr_name_to_sig = {} + + sourmash_sigs: List[sourmash.signature.SourmashSignature] = sourmash.load_file_as_signatures(zip_file_path) + sex_count = 0 + autosome_count = 0 + genome_count = 0 + for sig in sourmash_sigs: + name = sig.name + if name.endswith("-snipegenome"): + self.logger.debug(f"Loading genome signature: {name}") + restored_name = name.replace("-snipegenome", "") + genome_chr_name_to_sig[restored_name] = SnipeSig(sourmash_sig=sig, sig_type=SigType.GENOME) + genome_count += 1 + elif "sex" in name: + sex_count += 1 + genome_chr_name_to_sig[name.replace('sex-','')] = SnipeSig(sourmash_sig=sig, sig_type=SigType.GENOME) + elif "autosome" in name: + autosome_count += 1 + genome_chr_name_to_sig[name.replace('autosome-','')] = SnipeSig(sourmash_sig=sig, sig_type=SigType.GENOME) + else: + logging.warning(f"Unknown genome signature name: {name}, are you sure you generated this with `snipe sketch --ref`?") + + self.logger.debug("Loaded %d genome signatures and %d sex chrs and %d autosome chrs", genome_count, sex_count, autosome_count) + + if genome_count != 1: + logging.error(f"Expected 1 genome signature, found {genome_count}") + + + return genome_chr_name_to_sig + + +class PreparedQC(ReferenceQC): + r""" + Class for quality control (QC) analysis of sample signature against prepared snipe profiles. + """ + + def __init__(self, *, sample_sig: SnipeSig, snipe_db_path: str = '~/.snipe/dbs/', ref_id: Optional[str] = None, amplicon_id: Optional[str] = None, enable_logging: bool = False, **kwargs): + """ + Initialize the PreparedQC instance. + + **Parameters** + + - `sample_sig` (`SnipeSig`): The sample k-mer signature. + - `snipe_db_path` (`str`): Path to the local Snipe database directory. + - `ref_id` (`Optional[str]`): Reference identifier for selecting specific profiles. + - `enable_logging` (`bool`): Flag to enable detailed logging. + - `**kwargs`: Additional keyword arguments. 
+ """ + self.snipe_db_path = os.path.expanduser(snipe_db_path) + self.ref_id = ref_id + + # Ensure the local database directory exists + os.makedirs(self.snipe_db_path, exist_ok=True) + if enable_logging: + self.logger.debug(f"Local Snipe DB path set to: {self.snipe_db_path}") + else: + self.logger.debug("Logging is disabled for PreparedQC.") + + # Initialize without a reference signature for now; it can be set after downloading + super().__init__( + sample_sig=sample_sig, + reference_sig=None, # To be set after downloading + enable_logging=enable_logging, + **kwargs + ) + + def download_osf_db(self, url: str, save_path: str = '~/.snipe/dbs', force: bool = False) -> Optional[str]: + """ + Download a file from OSF using the provided URL. The file is saved with its original name + as specified by the OSF server via the Content-Disposition header. + + **Parameters** + + - `url` (`str`): The OSF URL to download the file from. + - `save_path` (`str`): The directory path where the file will be saved. Supports user (~) and environment variables. + Default is the local Snipe database directory. + - `force` (`bool`): If True, overwrite the file if it already exists. Default is False. + + **Returns** + + - `Optional[str]`: The path to the downloaded file if successful, else None. + + **Raises** + + - `requests.exceptions.RequestException`: If an error occurs during the HTTP request. + - `Exception`: For any other exceptions that may arise. 
+ """ + try: + # Expand user (~) and environment variables in save_path + expanded_save_path = os.path.expanduser(os.path.expandvars(save_path)) + self.logger.debug(f"Expanded save path: {expanded_save_path}") + + # Ensure the download URL ends with '/download' + parsed_url = urlparse(url) + if not parsed_url.path.endswith('/download'): + download_url = f"{url.rstrip('/')}/download" + else: + download_url = url + + self.logger.debug(f"Download URL: {download_url}") + + # Ensure the save directory exists + os.makedirs(expanded_save_path, exist_ok=True) + self.logger.debug(f"Save path verified/created: {expanded_save_path}") + + # Initiate the GET request with streaming + with requests.get(download_url, stream=True, allow_redirects=True) as response: + response.raise_for_status() # Raise an exception for HTTP errors + + # Attempt to extract filename from Content-Disposition + content_disposition = response.headers.get('Content-Disposition') + filename = self._extract_filename(content_disposition, parsed_url.path) + self.logger.debug(f"Filename determined: {filename}") + + # Define the full save path + full_save_path = os.path.join(expanded_save_path, filename) + self.logger.debug(f"Full save path: {full_save_path}") + + # Check if the file already exists + if os.path.exists(full_save_path): + if force: + self.logger.info(f"Overwriting existing file: {full_save_path}") + else: + self.logger.info(f"File already exists: {full_save_path}. 
Skipping download.") + return full_save_path + + # Get the total file size for the progress bar + total_size = int(response.headers.get('Content-Length', 0)) + + # Initialize the progress bar + with open(full_save_path, 'wb') as file, tqdm( + total=total_size, + unit='B', + unit_scale=True, + unit_divisor=1024, + desc=filename, + ncols=100 + ) as bar: + for chunk in response.iter_content(chunk_size=1024): + if chunk: # Filter out keep-alive chunks + file.write(chunk) + bar.update(len(chunk)) + + self.logger.info(f"File downloaded successfully: {full_save_path}") + return full_save_path + + except requests.exceptions.RequestException as req_err: + self.logger.error(f"Request error occurred while downloading {url}: {req_err}") + raise + except Exception as e: + self.logger.error(f"An unexpected error occurred while downloading {url}: {e}") + raise + + def _extract_filename(self, content_disposition: Optional[str], url_path: str) -> str: + """ + Extract filename from Content-Disposition header or fallback to URL path. + + **Parameters** + + - `content_disposition` (`Optional[str]`): The Content-Disposition header value. + - `url_path` (`str`): The path component of the URL. + + **Returns** + + - `str`: The extracted filename. 
+ """ + filename = None + if content_disposition: + self.logger.debug("Parsing Content-Disposition header for filename.") + parts = content_disposition.split(';') + for part in parts: + part = part.strip() + if part.lower().startswith('filename*='): + # Handle RFC 5987 encoding (e.g., filename*=UTF-8''example.txt) + encoded_filename = part.split('=', 1)[1].strip() + if "''" in encoded_filename: + filename = encoded_filename.split("''", 1)[1] + else: + filename = encoded_filename + self.logger.debug(f"Filename extracted from headers (RFC 5987): {filename}") + break + elif part.lower().startswith('filename='): + # Remove 'filename=' and any surrounding quotes + filename = part.split('=', 1)[1].strip(' "') + self.logger.debug(f"Filename extracted from headers: {filename}") + break + + if not filename: + self.logger.debug("Falling back to filename derived from URL path.") + filename = os.path.basename(url_path) + if not filename: + filename = 'downloaded_file' + self.logger.debug(f"Filename derived from URL: {filename}") + + return filename + + diff --git a/src/snipe/api/sketch.py b/src/snipe/api/sketch.py index 941da5f..838d3b0 100644 --- a/src/snipe/api/sketch.py +++ b/src/snipe/api/sketch.py @@ -8,12 +8,11 @@ import threading import queue from typing import Any, Dict, List, Optional, Tuple - -from pyfastx import Fastx as SequenceReader +from pyfastx import Fastx as SequenceReader # pylint: disable=no-name-in-module import sourmash from pathos.multiprocessing import ProcessingPool as Pool from snipe.api.enums import SigType -from snipe.api import SnipeSig +from snipe.api.snipe_sig import SnipeSig class SnipeSketch: diff --git a/src/snipe/api/snipe_sig.py b/src/snipe/api/snipe_sig.py index 0b5ee64..1ab149b 100644 --- a/src/snipe/api/snipe_sig.py +++ b/src/snipe/api/snipe_sig.py @@ -16,46 +16,6 @@ class SnipeSig: such as customized set operations and abundance management. 
""" - def _try_load_from_json(self, sourmash_sig: str) -> Union[List[sourmash.signature.SourmashSignature], None]: - r""" - Attempt to load sourmash signature from JSON string. - - Parameters: - sourmash_sig (str): JSON string representing a sourmash signature. - - Returns: - sourmash.signature.SourmashSignature or None if loading fails. - """ - try: - self.logger.debug("Trying to load sourmash signature from JSON.") - list_of_sigs = list(sourmash.load_signatures_from_json(sourmash_sig)) - return {sig.name: sig for sig in list_of_sigs} - except Exception as e: - self.logger.debug("Loading from JSON failed. Proceeding to file loading.", exc_info=e) - return None # Return None to indicate failure - - def _try_load_from_file(self, sourmash_sig_path: str) -> Union[List[sourmash.signature.SourmashSignature], None]: - r""" - Attempt to load sourmash signature(s) from a file. - - Parameters: - sourmash_sig_path (str): File path to a sourmash signature. - - Returns: - sourmash.signature.SourmashßSignature, list of sourmash.signature.SourmashSignature, or None if loading fails. 
- """ - self.logger.debug("Trying to load sourmash signature from file.") - try: - signatures = list(sourmash.load_file_as_signatures(sourmash_sig_path)) - self.logger.debug("Loaded %d sourmash signature(s) from file.", len(signatures)) - sigs_dict = {_sig.name: _sig for _sig in signatures} - self.logger.debug("Loaded sourmash signatures into sigs_dict: %s", sigs_dict) - return sigs_dict - except Exception as e: - self.logger.exception("Failed to load the sourmash signature from the file.", exc_info=e) - raise ValueError("An unexpected error occurred while loading the sourmash signature.") from e - - def __init__(self, *, sourmash_sig: Union[str, sourmash.signature.SourmashSignature, sourmash.signature.FrozenSourmashSignature], ksize: int = 51, scale: int = 10000, sig_type=SigType.SAMPLE, enable_logging: bool = False, **kwargs): r""" @@ -104,15 +64,12 @@ def __init__(self, *, sourmash_sig: Union[str, sourmash.signature.SourmashSignat sourmash_sigs: Dict[str, sourmash.signature.SourmashSignature] = {} _sourmash_sig: Union[sourmash.signature.SourmashSignature, sourmash.signature.FrozenSourmashSignature] = None + + self.chr_to_sig: Dict[str, SnipeSig] = {} self.logger.debug("Proceeding with a sigtype of %s", sig_type) - - - - - if not isinstance(sourmash_sig, (str, sourmash.signature.SourmashSignature, sourmash.signature.FrozenSourmashSignature)): # if the str is not a file path self.logger.error("Invalid type for sourmash_sig: %s", type(sourmash_sig).__name__) @@ -155,15 +112,28 @@ def __init__(self, *, sourmash_sig: Union[str, sourmash.signature.SourmashSignat elif sig_type == SigType.GENOME: if len(sourmash_sigs) > 1: for signame, sig in sourmash_sigs.items(): + self.logger.debug(f"Iterating over signature: {signame}") if signame.endswith("-snipegenome"): sig = sig.to_mutable() sig.name = sig.name.replace("-snipegenome", "") - self.logger.debug("Found a genome signature with a snipe modified name. 
Restoring original name `%s`.", sig.name) + self.logger.debug("Found a genome signature with the snipe suffix `-snipegenome`. Restoring original name `%s`.", sig.name) _sourmash_sig = sig - break + elif signame.startswith("sex-"): + self.logger.debug("Found a sex chr signature %s", signame) + sig = sig.to_mutable() + sig.name = signame.replace("sex-","") + self.chr_to_sig[sig.name] = SnipeSig(sourmash_sig=sig, sig_type=SigType.AMPLICON, enable_logging=enable_logging) + elif signame.startswith("autosome-"): + self.logger.debug("Found an autosome signature %s", signame) + sig = sig.to_mutable() + sig.name = signame.replace("autosome-","") + self.chr_to_sig[sig.name] = SnipeSig(sourmash_sig=sig, sig_type=SigType.AMPLICON, enable_logging=enable_logging) + else: + continue else: - self.logger.debug("Found multiple signature per the genome file, but none with a snipe modified name.") - raise ValueError("Found multiple signature per the genome file, but none with a snipe modified name.") + if not _sourmash_sig: + self.logger.debug("Found multiple signature per the genome file, but none with the snipe suffix `-snipegenome`.") + raise ValueError("Found multiple signature per the genome file, but none with the snipe suffix `-snipegenome`.") elif len(sourmash_sigs) == 1: self.logger.debug("Found a single signature in the genome sig input; Will use this signature.") _sourmash_sig = list(sourmash_sigs.values())[0] @@ -205,6 +175,45 @@ def __init__(self, *, sourmash_sig: Union[str, sourmash.signature.SourmashSignat self.logger.debug("Hashes sorted during initialization.") self.logger.debug("Sourmash signature loading completed successfully.") + def _try_load_from_json(self, sourmash_sig: str) -> Union[List[sourmash.signature.SourmashSignature], None]: + r""" + Attempt to load sourmash signature from JSON string. + + Parameters: + sourmash_sig (str): JSON string representing a sourmash signature. + + Returns: + sourmash.signature.SourmashSignature or None if loading fails. 
+ """ + try: + self.logger.debug("Trying to load sourmash signature from JSON.") + list_of_sigs = list(sourmash.load_signatures_from_json(sourmash_sig)) + return {sig.name: sig for sig in list_of_sigs} + except Exception as e: + self.logger.debug("Loading from JSON failed. Proceeding to file loading.", exc_info=e) + return None # Return None to indicate failure + + def _try_load_from_file(self, sourmash_sig_path: str) -> Union[List[sourmash.signature.SourmashSignature], None]: + r""" + Attempt to load sourmash signature(s) from a file. + + Parameters: + sourmash_sig_path (str): File path to a sourmash signature. + + Returns: + sourmash.signature.SourmashßSignature, list of sourmash.signature.SourmashSignature, or None if loading fails. + """ + self.logger.debug("Trying to load sourmash signature from file.") + try: + signatures = list(sourmash.load_file_as_signatures(sourmash_sig_path)) + self.logger.debug("Loaded %d sourmash signature(s) from file.", len(signatures)) + sigs_dict = {_sig.name: _sig for _sig in signatures} + self.logger.debug("Loaded sourmash signatures into sigs_dict: %s", sigs_dict) + return sigs_dict + except Exception as e: + self.logger.exception("Failed to load the sourmash signature from the file.", exc_info=e) + raise ValueError("An unexpected error occurred while loading the sourmash signature.") from e + # Setters and getters @property def hashes(self) -> np.ndarray: diff --git a/src/snipe/cli/main.py b/src/snipe/cli/main.py index 132be45..640f8c0 100644 --- a/src/snipe/cli/main.py +++ b/src/snipe/cli/main.py @@ -48,7 +48,7 @@ def cli(): @click.option('-s', '--sample', type=click.Path(exists=True), help='Sample FASTA file.') @click.option('-r', '--ref', type=click.Path(exists=True), help='Reference genome FASTA file.') @click.option('-a', '--amplicon', type=click.Path(exists=True), help='Amplicon FASTA file.') -@click.option('--ychr', type=click.Path(exists=True), help='Y chromosome signature file (required for --ref and --amplicon).') 
+@click.option('--ychr', type=click.Path(exists=True), help='Y chromosome FASTA file (overrides the reference ychr).') @click.option('-n', '--name', required=True, help='Signature name.') @click.option('-o', '--output-file', required=True, callback=validate_zip_file, help='Output file with .zip extension.') @click.option('-b', '--batch-size', default=100000, type=int, show_default=True, help='Batch size for sample sketching.') @@ -210,7 +210,9 @@ def sketch(ctx, sample: Optional[str], ref: Optional[str], amplicon: Optional[st # Define the top-level process_sample function def process_sample(sample_path: str, ref_path: str, amplicon_path: Optional[str], - advanced: bool, roi: bool, debug: bool) -> Dict[str, Any]: + advanced: bool, roi: bool, debug: bool, + ychr: Optional[str] = None, + vars_paths: Optional[List[str]] = None) -> Dict[str, Any]: """ Process a single sample for QC. @@ -221,6 +223,7 @@ def process_sample(sample_path: str, ref_path: str, amplicon_path: Optional[str] - advanced (bool): Flag to include advanced metrics. - roi (bool): Flag to calculate ROI. - debug (bool): Flag to enable debugging. + - vars_paths (Optional[List[str]]): List of variable signature file paths. Returns: - Dict[str, Any]: QC results for the sample. 
@@ -259,8 +262,31 @@ def process_sample(sample_path: str, ref_path: str, amplicon_path: Optional[str] enable_logging=debug ) - # calculate chromosome metrics - qc_instance.calculate_chromosome_metrics() + # Load variable signatures if provided + if vars_paths: + qc_instance.logger.debug(f"Loading {len(vars_paths)} variable signature(s).") + vars_dict: Dict[str, SnipeSig] = {} + vars_order: List[str] = [] + for path in vars_paths: + qc_instance.logger.debug(f"Loading variable signature from: {path}") + var_sig = SnipeSig(sourmash_sig=path, sig_type=SigType.AMPLICON, enable_logging=debug) + var_name = var_sig.name if var_sig.name else os.path.basename(path) + qc_instance.logger.debug(f"Loaded variable signature: {var_name}") + vars_dict[var_name] = var_sig + vars_order.append(var_name) + logger.debug(f"Loaded variable signature '{var_name}': {var_sig.name}") + # Pass variables to ReferenceQC + qc_instance.nonref_consume_from_vars(vars=vars_dict, vars_order=vars_order) + # No else block needed; variables are optional + + # Calculate chromosome metrics + # genome_chr_to_sig: Dict[str, SnipeSig] = qc_instance.load_genome_sig_to_dict(zip_file_path = ref_path) + chr_to_sig = reference_sig.chr_to_sig.copy() + if ychr: + ychr_sig = SnipeSig(sourmash_sig=ychr, sig_type=SigType.GENOME, enable_logging=debug) + chr_to_sig['y'] = ychr_sig + + qc_instance.calculate_chromosome_metrics(chr_to_sig) # Get aggregated stats aggregated_stats = qc_instance.get_aggregated_stats(include_advanced=advanced) @@ -304,11 +330,13 @@ def process_sample(sample_path: str, ref_path: str, amplicon_path: Optional[str] @click.option('--roi', is_flag=True, default=False, help='Calculate ROI for 1,2,5,9 folds.') @click.option('--cores', '-c', default=4, type=int, show_default=True, help='Number of CPU cores to use for parallel processing.') @click.option('--advanced', is_flag=True, default=False, help='Include advanced QC metrics.') +@click.option('--ychr', type=click.Path(exists=True), help='Y 
chromosome signature file (overrides the reference ychr).') @click.option('--debug', is_flag=True, default=False, help='Enable debugging and detailed logging.') @click.option('-o', '--output', required=True, callback=validate_tsv_file, help='Output TSV file for QC results.') +@click.option('--var', 'vars', multiple=True, type=click.Path(exists=True), help='Variable signature file path. Can be used multiple times.') def qc(ref: str, sample: List[str], samples_from_file: Optional[str], - amplicon: Optional[str], roi: bool, cores: int, advanced: bool, - debug: bool, output: str): + amplicon: Optional[str], roi: bool, cores: int, advanced: bool, + ychr: Optional[str], debug: bool, output: str, vars: List[str]): """ Perform quality control (QC) on multiple samples against a reference genome. @@ -377,18 +405,38 @@ def qc(ref: str, sample: List[str], samples_from_file: Optional[str], logger.error(f"Failed to load amplicon signature from {amplicon}: {e}") sys.exit(1) - # Prepare arguments for parallel processing - process_args = [ - (sample_path, ref, amplicon, advanced, roi, debug) - for sample_path in valid_samples - ] - + # Prepare variable signatures if provided + vars_paths = [] + if vars: + logger.info(f"Loading {len(vars)} variable signature(s).") + for path in vars: + if not os.path.exists(path): + logger.error(f"Variable signature file does not exist: {path}") + sys.exit(1) + vars_paths.append(os.path.abspath(path)) + logger.debug(f"Variable signature paths: {vars_paths}") + + # Prepare arguments for process_sample function + dict_process_args = [] + for sample_path in valid_samples: + dict_process_args.append({ + "sample_path": sample_path, + "ref_path": ref, + "amplicon_path": amplicon, + "advanced": advanced, + "roi": roi, + "debug": debug, + "ychr": ychr, + "vars_paths": vars_paths + }) + + # Process samples in parallel with progress bar results = [] with concurrent.futures.ProcessPoolExecutor(max_workers=cores) as executor: # Submit all tasks futures = { - 
executor.submit(process_sample, *args): args[0] for args in process_args + executor.submit(process_sample, **args): args for args in dict_process_args } # Iterate over completed futures with a progress bar for future in tqdm(concurrent.futures.as_completed(futures), total=len(futures), desc="Processing samples"): From 2f8b29d4b4f347205061b610c1103613a1e0ec8c Mon Sep 17 00:00:00 2001 From: Mohamed Abuelanin Date: Mon, 14 Oct 2024 17:20:05 -0700 Subject: [PATCH 7/7] handling empty signatures --- src/snipe/api/reference_QC.py | 26 +++++++++++++++----------- 1 file changed, 15 insertions(+), 11 deletions(-) diff --git a/src/snipe/api/reference_QC.py b/src/snipe/api/reference_QC.py index be43b27..9e76423 100644 --- a/src/snipe/api/reference_QC.py +++ b/src/snipe/api/reference_QC.py @@ -452,17 +452,21 @@ def _calculate_stats(self): self.logger.debug("Predicted assay type: %s", self.predicted_assay_type) self.logger.debug("Calculuating error and contamination indices.") - sample_nonref = self.sample_sig - self.reference_sig - sample_nonref_singletons = sample_nonref.count_singletons() - sample_nonref_non_singletons = sample_nonref.total_abundance - sample_nonref_singletons - sample_total_abundance = self.sample_sig.total_abundance - - predicted_error_index = sample_nonref_singletons / sample_total_abundance - predicted_contamination_index = sample_nonref_non_singletons / sample_total_abundance - - # predict error and contamination index - self.predicted_error_contamination_index["Predicted contamination index"] = predicted_contamination_index - self.predicted_error_contamination_index["Sequencing errors index"] = predicted_error_index + try: + sample_nonref = self.sample_sig - self.reference_sig + sample_nonref_singletons = sample_nonref.count_singletons() + sample_nonref_non_singletons = sample_nonref.total_abundance - sample_nonref_singletons + sample_total_abundance = self.sample_sig.total_abundance + + predicted_error_index = sample_nonref_singletons / 
sample_total_abundance + predicted_contamination_index = sample_nonref_non_singletons / sample_total_abundance + + # predict error and contamination index + self.predicted_error_contamination_index["Predicted contamination index"] = predicted_contamination_index + self.predicted_error_contamination_index["Sequencing errors index"] = predicted_error_index + # except zero division error + except ZeroDivisionError: + self.logger.error("Please check the sample signature, it seems to be empty.") def get_aggregated_stats(self, include_advanced: bool = False) -> Dict[str, Any]: