Skip to content

Commit

Permalink
Merge pull request #12 from snipe-bio/minor_fixes_qc
Browse files Browse the repository at this point in the history
  • Loading branch information
mr-eyes authored Oct 15, 2024
2 parents 3082e02 + a46d927 commit 2675856
Show file tree
Hide file tree
Showing 4 changed files with 116 additions and 38 deletions.
22 changes: 18 additions & 4 deletions src/snipe/api/reference_QC.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
import os
import requests
from tqdm import tqdm
import cgi
from urllib.parse import urlparse
from typing import Optional
import sourmash
Expand Down Expand Up @@ -1027,7 +1026,7 @@ def saturation_model(x, a, b):
predicted_coverage = min(predicted_coverage, max_coverage)

self.logger.debug("Predicted coverage at %.2f-fold increase: %f", extra_fold, predicted_coverage)
return predicted_coverage
return float(predicted_coverage)

def calculate_chromosome_metrics(self, chr_to_sig: Dict[str, SnipeSig]) -> Dict[str, Any]:
r"""
Expand Down Expand Up @@ -1110,8 +1109,23 @@ def calculate_chromosome_metrics(self, chr_to_sig: Dict[str, SnipeSig]) -> Dict[
chr_stats = chr_sample_sig.get_sample_stats
chr_to_mean_abundance[chr_name] = chr_stats["mean_abundance"]
self.logger.debug("\t-Mean abundance for %s: %f", chr_name, chr_stats["mean_abundance"])

self.chrs_stats.update(chr_to_mean_abundance)

# Chromosomes are numbered from 1 upward; sort them by number (the name may be a string for sex chromosomes), then prefix each with 'chr-'
def sort_chromosomes(chr_name):
    # Sort key: numeric chromosome names come first, ordered by their
    # integer value; non-numeric names (e.g. sex chromosomes like 'x'/'y')
    # follow, ordered lexicographically.
    try:
        numeric_name = int(chr_name)
    except ValueError:
        # Not an integer — group after the numeric chromosomes.
        return (1, chr_name)
    return (0, numeric_name)

# Create a new dictionary with sorted chromosome names and prefixed with 'chr-'
sorted_chr_to_mean_abundance = {
f"chr-{chr_name}": chr_to_mean_abundance[chr_name]
for chr_name in sorted(chr_to_mean_abundance, key=sort_chromosomes)
}

self.chrs_stats.update(sorted_chr_to_mean_abundance)

# Same as chr_to_mean_abundance, but excluding any chromosome whose name contains 'sex'
autosomal_chr_to_mean_abundance = {}
Expand Down
37 changes: 31 additions & 6 deletions src/snipe/api/snipe_sig.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import heapq
import logging

import sourmash.save_load
from snipe.api.enums import SigType
from typing import Any, Dict, Iterator, List, Optional, Union
import numpy as np
import sourmash

import os

# Configure the root logger to CRITICAL to suppress unwanted logs by default
logging.basicConfig(level=logging.CRITICAL)
Expand Down Expand Up @@ -150,6 +152,13 @@ def __init__(self, *, sourmash_sig: Union[str, sourmash.signature.SourmashSignat
self._name = _sourmash_sig.name
self._filename = _sourmash_sig.filename
self._track_abundance = _sourmash_sig.minhash.track_abundance

if self._name.endswith("-snipesample"):
self._name = self._name.replace("-snipesample", "")
self.logger.debug("Found a sample signature with the snipe suffix `-snipesample`. Restoring original name `%s`.", self._name)
elif self._name.endswith("-snipeamplicon"):
self._name = self._name.replace("-snipeamplicon", "")
self.logger.debug("Found an amplicon signature with the snipe suffix `-snipeamplicon`. Restoring original name `%s`.", self._name)

# If the signature does not track abundance, assume abundance of 1 for all hashes
if not self._track_abundance:
Expand Down Expand Up @@ -485,16 +494,34 @@ def _convert_to_sourmash_signature(self):
self.sourmash_sig = sourmash.signature.SourmashSignature(mh, name=self._name, filename=self._filename)
self.logger.debug("Conversion to sourmash.signature.SourmashSignature completed.")

def export(self, path, force=False) -> None:
    r"""
    Export the signature to a file.

    Parameters:
        path (str): Destination path; must end in ``.sig`` or ``.zip``.
        force (bool): Overwrite an existing ``.zip`` target instead of raising.
            (``.sig`` output overwrites unconditionally, matching previous behavior.)

    Raises:
        FileExistsError: If a ``.zip`` target exists and ``force`` is False.
        ValueError: If ``path`` has an unsupported extension.
        Exception: If writing the zip archive fails.
    """
    self._convert_to_sourmash_signature()
    if path.endswith(".sig"):
        self.logger.debug("Exporting signature to a .sig file.")
        with open(str(path), "wb") as fp:
            sourmash.signature.save_signatures_to_json([self.sourmash_sig], fp)

    elif path.endswith(".zip"):
        # Bug fix: `force` was accepted and documented but never consulted,
        # so an existing zip always raised regardless of the flag. Honor it
        # by removing the stale file first (NOTE(review): presumably
        # SaveSignatures_ZipFile would otherwise add to the existing
        # archive rather than replace it — confirm against sourmash docs).
        if os.path.exists(path):
            if not force:
                raise FileExistsError("Output file already exists.")
            os.remove(path)
        try:
            with sourmash.save_load.SaveSignatures_ZipFile(path) as save_sigs:
                save_sigs.add(self.sourmash_sig)
        except Exception as e:
            logging.error("Failed to export signatures to zip: %s", e)
            raise Exception(f"Failed to export signatures to zip: {e}") from e
    else:
        raise ValueError("Output file must be either a .sig or .zip file.")



def export_to_string(self):
r"""
Expand Down Expand Up @@ -924,8 +951,6 @@ def sum_signatures(cls, signatures: List['SnipeSig'], name: str = "summed_signat
for sig in signatures[1:]:
if sig.ksize != ksize or sig.scale != scale:
raise ValueError("All signatures must have the same ksize and scale.")
if sig.track_abundance != track_abundance:
raise ValueError("All signatures must have the same track_abundance setting.")

# Initialize iterators for each signature's hashes and abundances
iterators = []
Expand Down
93 changes: 66 additions & 27 deletions src/snipe/cli/cli_qc.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
from snipe.api.snipe_sig import SnipeSig
from snipe.api.reference_QC import ReferenceQC

import concurrent.futures
import signal

def process_sample(sample_path: str, ref_path: str, amplicon_path: Optional[str],
advanced: bool, roi: bool, debug: bool,
Expand Down Expand Up @@ -426,7 +428,7 @@ def qc(ref: str, sample: List[str], samples_from_file: Optional[str],
if samples_from_file:
logger.debug(f"Reading samples from file: {samples_from_file}")
try:
with open(samples_from_file, encoding='utf-8') as f:
with open(samples_from_file, 'r', encoding='utf-8') as f:
file_samples = {line.strip() for line in f if line.strip()}
samples_set.update(file_samples)
logger.debug(f"Collected {len(file_samples)} samples from file.")
Expand Down Expand Up @@ -492,32 +494,58 @@ def qc(ref: str, sample: List[str], samples_from_file: Optional[str],
"ychr": ychr,
"vars_paths": vars_paths
})

results = []

# Define a handler for graceful shutdown
def shutdown(signum, frame):
    # Signal handler for SIGINT/SIGTERM: log, cancel outstanding work, and exit.
    logger.warning("Shutdown signal received. Terminating all worker processes...")
    # NOTE(review): `executor` is a free variable from the enclosing function's
    # scope; if a signal arrives before the ProcessPoolExecutor is actually
    # created, this line raises NameError inside the handler — confirm the
    # handlers are registered only once `executor` is guaranteed to exist.
    # cancel_futures=True drops queued (not-yet-started) tasks immediately.
    executor.shutdown(wait=False, cancel_futures=True)
    sys.exit(1)

# Register signal handlers
signal.signal(signal.SIGINT, shutdown)
signal.signal(signal.SIGTERM, shutdown)

# Process samples in parallel with progress bar
results = []
with concurrent.futures.ProcessPoolExecutor(max_workers=cores) as executor:
# Submit all tasks
futures = {
executor.submit(process_sample, **args): args for args in dict_process_args
}
# Iterate over completed futures with a progress bar
for future in tqdm(concurrent.futures.as_completed(futures), total=len(futures), desc="Processing samples"):
sample = futures[future]
try:
result = future.result()
results.append(result)
except Exception as exc:
logger.error(f"Sample {sample} generated an exception: {exc}")
results.append({
"sample": os.path.splitext(os.path.basename(sample))[0],
"file_path": sample,
"QC_Error": str(exc)
})

# Create pandas DataFrame
logger.info("Aggregating results into DataFrame.")
df = pd.DataFrame(results)
try:
with concurrent.futures.ProcessPoolExecutor(max_workers=cores) as executor:
futures = {
executor.submit(process_sample, **args): args for args in dict_process_args
}

for future in tqdm(concurrent.futures.as_completed(futures), total=len(futures), desc="Processing samples"):
sample = futures[future]
try:
result = future.result()
results.append(result)
except Exception as exc:
logger.error(f"Sample {sample['sample_path']} generated an exception: {exc}")
results.append({
"sample": os.path.splitext(os.path.basename(sample['sample_path']))[0],
"file_path": sample['sample_path'],
"QC_Error": str(exc)
})
except KeyboardInterrupt:
logger.warning("KeyboardInterrupt received. Shutting down...")
sys.exit(1)
except Exception as e:
logger.error(f"An unexpected error occurred: {e}")
sys.exit(1)

# Separate successful and failed results
succeeded = [res for res in results if "QC_Error" not in res]
failed = [res for res in results if "QC_Error" in res]

# Handle complete failure
if len(succeeded) == 0:
logger.error("All samples failed during QC processing. Output TSV will not be generated.")
sys.exit(1)

# Prepare the command-line invocation for comments
command_invocation = ' '.join(sys.argv)

# Create pandas DataFrame for succeeded samples
df = pd.DataFrame(succeeded)

# Reorder columns to have 'sample' and 'file_path' first, if they exist
cols = list(df.columns)
Expand All @@ -529,14 +557,25 @@ def qc(ref: str, sample: List[str], samples_from_file: Optional[str],
reordered_cols += cols
df = df[reordered_cols]

# Export to TSV
# Export to TSV with comments
try:
df.to_csv(output, sep='\t', index=False)
with open(output, 'w', encoding='utf-8') as f:
# Write comment with command invocation
f.write(f"# Command: {command_invocation}\n")
# Write the DataFrame to the file
df.to_csv(f, sep='\t', index=False)
logger.info(f"QC results successfully exported to {output}")
except Exception as e:
logger.error(f"Failed to export QC results to {output}: {e}")
sys.exit(1)

# Report failed samples if any
if failed:
failed_samples = [res['sample'] for res in failed]
logger.warning(f"The following {len(failed_samples)} sample(s) failed during QC processing:")
for sample in failed_samples:
logger.warning(f"- {sample}")

end_time = time.time()
elapsed_time = end_time - start_time
logger.info(f"QC process completed in {elapsed_time:.2f} seconds.")
2 changes: 1 addition & 1 deletion tests/api/test_reference_qc.py
Original file line number Diff line number Diff line change
Expand Up @@ -611,7 +611,7 @@ def test_predict_coverage_already_full_coverage(self):
)
current_coverage = qc.genome_stats["Genome coverage index"]
self.assertEqual(current_coverage, 1.0)
predicted_coverage = qc.predict_coverage(extra_fold=1.0, n=10)
predicted_coverage = qc.predict_coverage(extra_fold=2.0)
# Predicted coverage should still be 1.0
self.assertEqual(predicted_coverage, 1.0)

Expand Down

0 comments on commit 2675856

Please sign in to comment.