From 46963625ad3f8fc2229589209beb9414f640b0d8 Mon Sep 17 00:00:00 2001
From: Roman Joeres
Date: Fri, 15 Mar 2024 18:22:49 +0100
Subject: [PATCH] Extending handling of embeddings at input plus more tests

---
 datasail/cluster/clustering.py    |   2 +-
 datasail/cluster/ecfp.py          |  15 +-
 datasail/cluster/tanimoto.py      |  29 ----
 datasail/cluster/vectors.py       | 136 ++++++++++++++++++
 datasail/reader/read_genomes.py   |  49 ++-----
 datasail/reader/read_molecules.py |  57 +++-----
 datasail/reader/read_other.py     |  25 +---
 datasail/reader/read_proteins.py  |  70 ++-------
 datasail/reader/utils.py          |  74 +++++++++-
 tests/test_clustering.py          |  57 ++++++--
 ...st_clustom_args.py => test_custom_args.py} |   0
 tests/test_pipeline.py            |  48 +++++++
 12 files changed, 347 insertions(+), 215 deletions(-)
 delete mode 100644 datasail/cluster/tanimoto.py
 create mode 100644 datasail/cluster/vectors.py
 rename tests/{test_clustom_args.py => test_custom_args.py} (100%)

diff --git a/datasail/cluster/clustering.py b/datasail/cluster/clustering.py
index 1622977..66c4272 100644
--- a/datasail/cluster/clustering.py
+++ b/datasail/cluster/clustering.py
@@ -12,7 +12,7 @@
 from datasail.cluster.mash import run_mash
 from datasail.cluster.mmseqs2 import run_mmseqs
 from datasail.cluster.mmseqspp import run_mmseqspp
-from datasail.cluster.tanimoto import run_tanimoto
+from datasail.cluster.vectors import run_tanimoto
 from datasail.cluster.utils import heatmap
 from datasail.cluster.wlk import run_wlk
 from datasail.reader.utils import DataSet
diff --git a/datasail/cluster/ecfp.py b/datasail/cluster/ecfp.py
index e36201e..b0e3331 100644
--- a/datasail/cluster/ecfp.py
+++ b/datasail/cluster/ecfp.py
@@ -1,20 +1,21 @@
-import numpy as np
-from rdkit import Chem, DataStructs, RDLogger
+from rdkit import Chem, RDLogger
 from rdkit.Chem import AllChem
 from rdkit.Chem.Scaffolds.MurckoScaffold import MakeScaffoldGeneric
 from rdkit.Chem.rdchem import MolSanitizeException
 
 from datasail.cluster.utils import read_molecule_encoding
+from datasail.cluster.vectors import run, SIM_OPTIONS
 from datasail.reader.utils import DataSet
 from datasail.settings import LOGGER
 
 
-def run_ecfp(dataset: DataSet) -> None:
+def run_ecfp(dataset: DataSet, method: SIM_OPTIONS = "Tanimoto") -> None:
     """
-    Compute 1024Bit-ECPFs for every molecule in the dataset and then compute pairwise Tanimoto-Scores of them.
+    Compute 1024-bit ECFPs for every molecule in the dataset and then compute pairwise similarities of them.
 
     Args:
         dataset: The dataset to compute pairwise, elementwise similarities for
+        method: The similarity measure to use. Default is "Tanimoto".
""" lg = RDLogger.logger() lg.setLevel(RDLogger.CRITICAL) @@ -54,14 +55,8 @@ def run_ecfp(dataset: DataSet) -> None: fps.append(AllChem.GetMorganFingerprintAsBitVect(Chem.MolFromSmiles(scaffold), 2, nBits=1024)) LOGGER.info(f"Reduced {len(dataset.names)} molecules to {len(dataset.cluster_names)}") - LOGGER.info("Compute Tanimoto Coefficients") - count = len(dataset.cluster_names) - dataset.cluster_similarity = np.zeros((count, count)) - for i in range(count): - dataset.cluster_similarity[i, i] = 1 - dataset.cluster_similarity[i, :i] = DataStructs.BulkTanimotoSimilarity(fps[i], fps[:i]) - dataset.cluster_similarity[:i, i] = dataset.cluster_similarity[i, :i] + run(dataset, fps, method) dataset.cluster_map = dict((name, Chem.MolToSmiles(scaffolds[name])) for name in dataset.names) diff --git a/datasail/cluster/tanimoto.py b/datasail/cluster/tanimoto.py deleted file mode 100644 index 8b56be4..0000000 --- a/datasail/cluster/tanimoto.py +++ /dev/null @@ -1,29 +0,0 @@ -import numpy as np -from rdkit import DataStructs - -from datasail.reader.utils import DataSet -from datasail.settings import LOGGER - - -def run_tanimoto(dataset: DataSet) -> None: - """ - Compute pairwise Tanimoto-Scores of the given dataset. - - Args: - dataset: The dataset to compute pairwise, elementwise similarities for - """ - LOGGER.info("Start Tanimoto clustering") - - if not isinstance(list(dataset.data.values())[0], np.ndarray): - raise ValueError("Tanimoto-Clustering can only be applied to already computed embeddings.") - - count = len(dataset.cluster_names) - dataset.cluster_similarity = np.zeros((count, count)) - fps = [dataset.data[name] for name in dataset.names] - for i in range(count): - dataset.cluster_similarity[i, i] = 1 - dataset.cluster_similarity[i, :i] = DataStructs.BulkTanimotoSimilarity(fps[i], fps[:i]) - dataset.cluster_similarity[:i, i] = dataset.cluster_similarity[i, :i] - - dataset.cluster_names = dataset.names - dataset.cluster_map = dict((n, n) for n in dataset.names) diff --git a/datasail/cluster/vectors.py b/datasail/cluster/vectors.py new file mode 100644 index 0000000..b89f907 --- /dev/null +++ b/datasail/cluster/vectors.py @@ -0,0 +1,136 @@ +import copy +from typing import Literal + +import numpy as np +from rdkit import DataStructs + +from datasail.reader.utils import DataSet +from datasail.settings import LOGGER + +SIM_OPTIONS = Literal[ + "AllBit", "Asymmetric", "BraunBlanquet", "Cosine", "Dice", "Kulczynski", "McConnaughey", "OnBit", "RogotGoldberg", + "Russel", "Sokal", "Tanimoto", "Jaccard" +] + + +def get_rdkit_fct(method: SIM_OPTIONS): + """ + Get the RDKit function for the given similarity measure. + + Args: + method: The name of the similarity measure to get the function for. + + Returns: + The RDKit function for the given similarity measure. 
+ """ + if method == "AllBit": + return DataStructs.BulkAllBitSimilarity + if method == "Asymmetric": + return DataStructs.BulkAsymmetricSimilarity + if method == "BraunBlanquet": + return DataStructs.BulkBraunBlanquetSimilarity + if method == "Cosine": + return DataStructs.BulkCosineSimilarity + if method == "Dice": + return DataStructs.BulkDiceSimilarity + if method == "Kulczynski": + return DataStructs.BulkKulczynskiSimilarity + if method == "McConnaughey": + return DataStructs.BulkMcConnaugheySimilarity + if method == "OnBit": + return DataStructs.BulkOnBitSimilarity + if method == "RogotGoldberg": + return DataStructs.BulkRogotGoldbergSimilarity + if method == "Russel": + return DataStructs.BulkRusselSimilarity + if method == "Sokal": + return DataStructs.BulkSokalSimilarity + if method == "Tanimoto" or method == "Jaccard": + return DataStructs.BulkTanimotoSimilarity + if method == "Tversky": + return DataStructs.BulkTverskySimilarity + raise ValueError(f"Unknown method {method}") + + +def iterable2intvect(it): + """ + Convert an iterable to an RDKit LongSparseIntVect. + + Args: + it: The iterable to convert. + + Returns: + The RDKit LongSparseIntVect. + """ + output = DataStructs.LongSparseIntVect(len(it)) + for i, v in enumerate(it): + output[i] = max(-2_147_483_648, min(2_147_483_647, int(v))) + return output + + +def iterable2bitvect(it): + """ + Convert an iterable to an RDKit ExplicitBitVect. + + Args: + it: The iterable to convert. + + Returns: + The RDKit ExplicitBitVect. + """ + output = DataStructs.ExplicitBitVect(len(it)) + output.SetBitsFromList([i for i, v in enumerate(it) if v]) + return output + + +def run_tanimoto(dataset: DataSet, method: SIM_OPTIONS = "Tanimoto") -> None: + """ + Compute pairwise Tanimoto-Scores of the given dataset. + + Args: + dataset: The dataset to compute pairwise, elementwise similarities for + method: The similarity measure to use. Default is "Tanimoto". + """ + LOGGER.info("Start Tanimoto clustering") + + embed = dataset.data[dataset.names[0]] + if isinstance(embed, (list, tuple, np.ndarray)): + if isinstance(embed[0], int) or np.issubdtype(embed[0].dtype, int): + if method in ["AllBit", "Asymmetric", "BraunBlanquet", "Cosine", "Kulczynski", "McConnaughey", "OnBit", + "RogotGoldberg", "Russel", "Sokal"]: + dataset.data = {k: iterable2bitvect(v) for k, v in dataset.data.items()} + else: + dataset.data = {k: iterable2intvect(v) for k, v in dataset.data.items()} + embed = dataset.data[dataset.names[0]] + else: + raise ValueError("Embeddings with non-integer elements are not supported at the moment.") + if not isinstance(embed, + (DataStructs.ExplicitBitVect, DataStructs.LongSparseIntVect, DataStructs.IntSparseIntVect)): + raise ValueError(f"Unsupported embedding type {type(embed)}. Please use either RDKit datastructures, lists, " + f"tuples or one-dimensional numpy arrays.") + fps = [dataset.data[name] for name in dataset.names] + run(dataset, fps, method) + + dataset.cluster_names = copy.deepcopy(dataset.names) + dataset.cluster_map = dict((n, n) for n in dataset.names) + + +def run(dataset, fps, method): + """ + Compute pairwise similarities of the given fingerprints. + + Args: + dataset: The dataset to compute pairwise similarities for. + fps: The fingerprints to compute pairwise similarities for. + method: The similarity measure to use. 
+ """ + fct = get_rdkit_fct(method) + dataset.cluster_similarity = np.zeros((len(dataset.data), len(dataset.data))) + for i in range(len(dataset.data)): + dataset.cluster_similarity[i, i] = 1 + dataset.cluster_similarity[i, :i] = fct(fps[i], fps[:i]) + dataset.cluster_similarity[:i, i] = dataset.cluster_similarity[i, :i] + + min_val = np.min(dataset.cluster_similarity) + max_val = np.max(dataset.cluster_similarity) + dataset.cluster_similarity = (dataset.cluster_similarity - min_val) / (max_val - min_val) diff --git a/datasail/reader/read_genomes.py b/datasail/reader/read_genomes.py index 02910e7..9e934b2 100644 --- a/datasail/reader/read_genomes.py +++ b/datasail/reader/read_genomes.py @@ -1,13 +1,8 @@ -import pickle -from pathlib import Path -from typing import List, Tuple, Optional, Generator, Callable, Iterable - -import h5py +from typing import List, Tuple, Optional from datasail.reader.read_molecules import remove_duplicate_values -from datasail.reader.read_proteins import parse_fasta -from datasail.reader.utils import DataSet, read_data, DATA_INPUT, MATRIX_INPUT, read_folder, read_csv -from datasail.settings import G_TYPE, UNK_LOCATION, FORM_FASTA, FASTA_FORMATS, FORM_GENOMES +from datasail.reader.utils import DataSet, read_data, DATA_INPUT, MATRIX_INPUT, read_folder, read_data_input +from datasail.settings import G_TYPE, UNK_LOCATION, FORM_FASTA, FORM_GENOMES def read_genome_data( @@ -39,38 +34,12 @@ def read_genome_data( A dataset storing all information on that datatype """ dataset = DataSet(type=G_TYPE, location=UNK_LOCATION, format=FORM_FASTA) - if isinstance(data, Path): - if data.is_file(): - if data.suffix[1:].lower() in FASTA_FORMATS: - dataset.data = parse_fasta(data) - elif data.suffix[1:].lower() == "tsv": - dataset.data = dict(read_csv(data, sep="\t")) - elif data.suffix[1:].lower() == "csv": - dataset.data = dict(read_csv(data, sep=",")) - elif data.suffix[1:].lower() == "pkl": - with open(data, "rb") as file: - dataset.data = dict(pickle.load(file)) - elif data.suffix[1:].lower() == "h5": - with h5py.File(data) as file: - dataset.data = dict(file[k] for k in file.keys()) - else: - raise ValueError() - elif data.is_dir(): - dataset.data = dict(read_folder(data)) - dataset.format = FORM_GENOMES - else: - raise ValueError() - dataset.location = data - elif (isinstance(data, list) or isinstance(data, tuple)) and isinstance(data[0], Iterable) and len(data[0]) == 2: - dataset.data = dict(data) - elif isinstance(data, dict): - dataset.data = data - elif isinstance(data, Callable): - dataset.data = data() - elif isinstance(data, Generator): - dataset.data = dict(data) - else: - raise ValueError() + + def read_dir(ds): + ds.data = dict(read_folder(data)) + ds.format = FORM_GENOMES + + read_data_input(data, dataset, read_dir) dataset = read_data(weights, strats, sim, dist, inter, index, num_clusters, tool_args, dataset) dataset = remove_duplicate_values(dataset, dataset.data) diff --git a/datasail/reader/read_molecules.py b/datasail/reader/read_molecules.py index fb4f65a..86afdd6 100644 --- a/datasail/reader/read_molecules.py +++ b/datasail/reader/read_molecules.py @@ -1,8 +1,5 @@ -import pickle -from pathlib import Path -from typing import List, Tuple, Optional, Callable, Generator, Iterable +from typing import List, Tuple, Optional -import h5py import numpy as np from rdkit import Chem from rdkit.Chem import MolFromMol2File, MolFromMolFile, MolFromPDBFile, MolFromTPLFile, MolFromXYZFile @@ -11,7 +8,7 @@ except ImportError: MolFromMrvFile = None -from datasail.reader.utils 
import read_csv, DataSet, read_data, DATA_INPUT, MATRIX_INPUT +from datasail.reader.utils import DataSet, read_data, DATA_INPUT, MATRIX_INPUT, read_data_input from datasail.settings import M_TYPE, UNK_LOCATION, FORM_SMILES @@ -55,42 +52,18 @@ def read_molecule_data( A dataset storing all information on that datatype """ dataset = DataSet(type=M_TYPE, format=FORM_SMILES, location=UNK_LOCATION) - if isinstance(data, Path): - if data.is_file(): - if data.suffix[1:].lower() == "tsv": - dataset.data = dict(read_csv(data, sep="\t")) - elif data.suffix[1:].lower() == "csv": - dataset.data = dict(read_csv(data, sep=",")) - elif data.suffix[1:].lower() == "pkl": - with open(data, "rb") as file: - dataset.data = dict(pickle.load(file)) - elif data.suffix[1:].lower() == "h5": - with h5py.File(data) as file: - dataset.data = dict(file[k] for k in file.keys()) + + def read_dir(ds: DataSet): + ds.data = {} + for file in data.iterdir(): + if file.suffix[1:].lower() != "sdf" and mol_reader[file.suffix[1:].lower()] is not None: + ds.data[file.stem] = mol_reader[file.suffix[1:].lower()](file) else: - raise ValueError() - elif data.is_dir(): - dataset.data = {} - for file in data.iterdir(): - if file.suffix[1:].lower() != "sdf" and mol_reader[file.suffix[1:].lower()] is not None: - dataset.data[file.stem] = mol_reader[file.suffix[1:].lower()](file) - else: - suppl = Chem.SDMolSupplier(file) - for i, mol in enumerate(suppl): - dataset.data[f"{file.stem}_{i}"] = mol - else: - raise ValueError() - dataset.location = data - elif (isinstance(data, list) or isinstance(data, tuple)) and isinstance(data[0], Iterable) and len(data[0]) == 2: - dataset.data = dict(data) - elif isinstance(data, dict): - dataset.data = data - elif isinstance(data, Callable): - dataset.data = data() - elif isinstance(data, Generator): - dataset.data = dict(data) - else: - raise ValueError() + suppl = Chem.SDMolSupplier(file) + for i, mol in enumerate(suppl): + ds.data[f"{file.stem}_{i}"] = mol + + read_data_input(data, dataset, read_dir) dataset = read_data(weights, strats, sim, dist, inter, index, num_clusters, tool_args, dataset) dataset = remove_molecule_duplicates(dataset) @@ -109,6 +82,10 @@ def remove_molecule_duplicates(dataset: DataSet) -> DataSet: Returns: Update arguments as teh location of the data might change and an ID-Map file might be added. 
""" + if isinstance(dataset.data[dataset.names[0]], (list, tuple, np.ndarray)): + # TODO: proper check for duplicate embeddings + dataset.id_map = {n: n for n in dataset.names} + return dataset # Extract invalid molecules non_mols = [] diff --git a/datasail/reader/read_other.py b/datasail/reader/read_other.py index 4140085..c96ecae 100644 --- a/datasail/reader/read_other.py +++ b/datasail/reader/read_other.py @@ -1,9 +1,8 @@ -from pathlib import Path -from typing import List, Tuple, Optional, Generator, Callable +from typing import List, Tuple, Optional from datasail.reader.read_genomes import read_folder from datasail.reader.read_molecules import remove_duplicate_values -from datasail.reader.utils import DataSet, read_data, DATA_INPUT, MATRIX_INPUT +from datasail.reader.utils import DataSet, read_data, DATA_INPUT, MATRIX_INPUT, read_data_input from datasail.settings import O_TYPE, UNK_LOCATION, FORM_OTHER @@ -28,7 +27,6 @@ def read_other_data( strats: Stratification for the data sim: Similarity file or metric dist: Distance file or metric - id_map: Mapping of ids in case of duplicates in the dataset inter: Interaction, alternative way to compute weights index: Index of the entities in the interaction file num_clusters: Number of clusters to compute for this dataset @@ -38,20 +36,11 @@ def read_other_data( A dataset storing all information on that datatype """ dataset = DataSet(type=O_TYPE, location=UNK_LOCATION, format=FORM_OTHER) - if isinstance(data, Path): - if data.exists(): - dataset.data = read_folder(data) - dataset.location = data - else: - raise ValueError() - elif isinstance(data, dict): - dataset.data = data - elif isinstance(data, Callable): - dataset.data = data() - elif isinstance(data, Generator): - dataset.data = dict(data) - else: - raise ValueError() + + def read_dir(ds): + ds.data = dict(read_folder(data)) + + read_data_input(data, dataset, read_dir) dataset, inter = read_data(weights, strats, sim, dist, inter, index, num_clusters, tool_args, dataset) dataset = remove_duplicate_values(dataset, dataset.data) diff --git a/datasail/reader/read_proteins.py b/datasail/reader/read_proteins.py index e9026b3..d73212e 100644 --- a/datasail/reader/read_proteins.py +++ b/datasail/reader/read_proteins.py @@ -1,13 +1,11 @@ -import pickle from pathlib import Path -from typing import Generator, Tuple, Dict, List, Optional, Set, Callable, Iterable +from typing import Tuple, Dict, List, Optional, Set import numpy as np -import h5py from datasail.reader.read_molecules import remove_duplicate_values -from datasail.reader.utils import read_csv, DataSet, read_data, read_folder, DATA_INPUT, MATRIX_INPUT -from datasail.settings import P_TYPE, UNK_LOCATION, FORM_PDB, FORM_FASTA, FASTA_FORMATS +from datasail.reader.utils import DataSet, read_data, read_folder, DATA_INPUT, MATRIX_INPUT, read_data_input +from datasail.settings import P_TYPE, UNK_LOCATION, FORM_PDB, FORM_FASTA def read_protein_data( @@ -39,37 +37,11 @@ def read_protein_data( A dataset storing all information on that datatype """ dataset = DataSet(type=P_TYPE, location=UNK_LOCATION) - if isinstance(data, Path): - if data.is_file(): - if data.suffix[1:] in FASTA_FORMATS: - dataset.data = parse_fasta(data) - elif data.suffix[1:].lower() == "tsv": - dataset.data = dict(read_csv(data, sep="\t")) - elif data.suffix[1:].lower() == "csv": - dataset.data = dict(read_csv(data, sep=",")) - elif data.suffix[1:].lower() == "pkl": - with open(data, "rb") as file: - dataset.data = dict(pickle.load(file)) - elif data.suffix[1:].lower() == "h5": - 
with h5py.File(data) as file: - dataset.data = dict(file[k] for k in file.keys()) - else: - raise ValueError() - elif data.is_dir(): - dataset.data = dict(read_folder(data, "pdb")) - else: - raise ValueError() - dataset.location = data - elif (isinstance(data, list) or isinstance(data, tuple)) and isinstance(data[0], Iterable) and len(data[0]) == 2: - dataset.data = dict(data) - elif isinstance(data, dict): - dataset.data = data - elif isinstance(data, Callable): - dataset.data = data() - elif isinstance(data, Generator): - dataset.data = dict(data) - else: - raise ValueError() + + def read_dir(ds): + ds.data = dict(read_folder(data, "pdb")) + + read_data_input(data, dataset, read_dir) dataset.format = FORM_PDB if str(next(iter(dataset.data.values()))).endswith(".pdb") else FORM_FASTA @@ -79,32 +51,6 @@ def read_protein_data( return dataset -def parse_fasta(path: Path = None) -> Dict[str, str]: - """ - Parse a FASTA file and do some validity checks if requested. - - Args: - path: Path to the FASTA file - - Returns: - Dictionary mapping sequences IDs to amino acid sequences - """ - seq_map = {} - - with open(path, "r") as fasta: - for line in fasta.readlines(): - line = line.strip() - if len(line) == 0: - continue - if line[0] == '>': - entry_id = line[1:] # .replace(" ", "_") - seq_map[entry_id] = '' - else: - seq_map[entry_id] += line - - return seq_map - - def check_pdb_pair(pdb_seqs1: List[str], pdb_seqs2: List[str]) -> bool: """ Entry point for the comparison of two PDB files. diff --git a/datasail/reader/utils.py b/datasail/reader/utils.py index fac4872..71d5967 100644 --- a/datasail/reader/utils.py +++ b/datasail/reader/utils.py @@ -1,13 +1,15 @@ +import pickle from argparse import Namespace from dataclasses import dataclass, fields from pathlib import Path -from typing import Generator, Tuple, List, Optional, Dict, Union, Any, Callable +from typing import Generator, Tuple, List, Optional, Dict, Union, Any, Callable, Iterable +import h5py import numpy as np import pandas as pd from datasail.reader.validate import validate_user_args -from datasail.settings import get_default, SIM_ALGOS, DIST_ALGOS, UNK_LOCATION, format2ending +from datasail.settings import get_default, SIM_ALGOS, DIST_ALGOS, UNK_LOCATION, format2ending, FASTA_FORMATS DATA_INPUT = Optional[Union[str, Path, Dict[str, Union[str, np.ndarray]], Callable[..., Dict[str, Union[str, np.ndarray]]], Generator[Tuple[str, Union[str, np.ndarray]], None, None]]] MATRIX_INPUT = Optional[Union[str, Path, Tuple[List[str], np.ndarray], Callable[..., Tuple[List[str], np.ndarray]]]] @@ -358,6 +360,74 @@ def read_folder(folder_path: Path, file_extension: Optional[str] = None) -> Gene yield filename.stem, filename +def read_data_input(data: DATA_INPUT, dataset: DataSet, read_dir: Callable[[DataSet], None]): + """ + Read in the data from different sources and store it in the dataset. 
+
+    Args:
+        data: Data input
+        dataset: Dataset to store the data in
+        read_dir: Function to read in a directory
+    """
+    if isinstance(data, Path):
+        suffix = data.suffix[1:].lower()
+        if data.is_file():
+            if suffix in FASTA_FORMATS:
+                dataset.data = parse_fasta(data)
+            elif suffix == "tsv":
+                dataset.data = dict(read_csv(data, sep="\t"))
+            elif suffix == "csv":
+                dataset.data = dict(read_csv(data, sep=","))
+            elif suffix == "pkl":
+                with open(data, "rb") as file:
+                    dataset.data = dict(pickle.load(file))
+            elif suffix == "h5":
+                with h5py.File(data) as file:
+                    dataset.data = {k: np.array(file[k]) for k in file.keys()}
+            else:
+                raise ValueError("Unknown file format. Supported formats are: .fasta, .fna, .fa, .tsv, .csv, .pkl, .h5")
+        elif data.is_dir():
+            read_dir(dataset)
+        else:
+            raise ValueError("Unknown data input type. Path encodes neither a file nor a directory.")
+        dataset.location = data
+    elif isinstance(data, (list, tuple)) and isinstance(data[0], Iterable) and len(data[0]) == 2:
+        dataset.data = dict(data)
+    elif isinstance(data, dict):
+        dataset.data = data
+    elif isinstance(data, Callable):
+        dataset.data = data()
+    elif isinstance(data, Generator):
+        dataset.data = dict(data)
+    else:
+        raise ValueError("Unknown data input type.")
+
+
+def parse_fasta(path: Path = None) -> Dict[str, str]:
+    """
+    Parse a FASTA file.
+
+    Args:
+        path: Path to the FASTA file
+
+    Returns:
+        Dictionary mapping sequence IDs to amino acid sequences
+    """
+    seq_map = {}
+
+    with open(path, "r") as fasta:
+        for line in fasta.readlines():
+            line = line.strip()
+            if len(line) == 0:
+                continue
+            if line[0] == '>':
+                entry_id = line[1:]  # .replace(" ", "_")
+                seq_map[entry_id] = ''
+            else:
+                seq_map[entry_id] += line
+
+    return seq_map
+
+
 def get_prefix_args(prefix, **kwargs) -> Dict[str, Any]:
     """
     Remove prefix from keys and return those key-value-pairs.
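To illustrate the dispatch that read_data_input centralizes (suffix-based for files, callback-based for directories, direct conversion for dicts, pair lists, callables, and generators), a minimal sketch; the file name, sequences, and no-op directory callback are made up:

    from pathlib import Path

    from datasail.reader.utils import DataSet, read_data_input
    from datasail.settings import P_TYPE, UNK_LOCATION

    dataset = DataSet(type=P_TYPE, location=UNK_LOCATION)

    # Suffix dispatch: .fasta/.fna/.fa -> parse_fasta, .tsv/.csv -> read_csv,
    # .pkl -> pickle, .h5 -> h5py; assumes seqs.fasta exists on disk.
    read_data_input(Path("seqs.fasta"), dataset, lambda ds: None)

    # Non-Path inputs: dict, iterable of (name, value) pairs, and generator.
    read_data_input({"P1": "MKTAYIAK", "P2": "MGSSHHHH"}, dataset, lambda ds: None)
    read_data_input([("P1", "MKTAYIAK"), ("P2", "MGSSHHHH")], dataset, lambda ds: None)
    read_data_input((pair for pair in [("P3", "MAHHH")]), dataset, lambda ds: None)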
diff --git a/tests/test_clustering.py b/tests/test_clustering.py index 6092f20..f0217cb 100644 --- a/tests/test_clustering.py +++ b/tests/test_clustering.py @@ -4,6 +4,9 @@ import numpy as np import pandas as pd import pytest +from rdkit import Chem +from rdkit.Chem import AllChem, Descriptors +from rdkit.ML.Descriptors import MoleculeDescriptors from datasail.cluster.cdhit import run_cdhit from datasail.cluster.clustering import additional_clustering, cluster @@ -13,10 +16,12 @@ from datasail.cluster.mash import run_mash from datasail.cluster.mmseqs2 import run_mmseqs from datasail.cluster.mmseqspp import run_mmseqspp +from datasail.cluster.vectors import run_tanimoto from datasail.cluster.tmalign import run_tmalign from datasail.cluster.wlk import run_wlk -from datasail.reader.read_proteins import parse_fasta, read_folder -from datasail.reader.utils import DataSet, read_csv + +from datasail.reader.read_proteins import read_folder +from datasail.reader.utils import DataSet, read_csv, parse_fasta from datasail.reader.validate import check_cdhit_arguments, check_foldseek_arguments, check_mmseqs_arguments, \ check_mash_arguments, check_mmseqspp_arguments, check_diamond_arguments from datasail.sail import datasail @@ -24,7 +29,12 @@ DIAMOND -@pytest.mark.todo +@pytest.fixture() +def md_calculator(): + descriptor_names = [desc[0] for desc in Descriptors._descList] + return MoleculeDescriptors.MolecularDescriptorCalculator(descriptor_names) + + def test_additional_clustering(): names = ["1", "2", "3", "4", "5", "6", "7", "8", "9", "10"] base_map = dict((n, n) for n in names) @@ -82,7 +92,6 @@ def test_additional_clustering(): assert np.min(s_dataset.cluster_similarity) == 0 assert np.max(s_dataset.cluster_similarity) == 1 assert s_dataset.cluster_distance is None - # assert [s_dataset.cluster_weights[i] for i in s_dataset.cluster_names] == [18, 12, 6, 12, 4] d_dataset = additional_clustering(d_dataset, n_clusters=5, linkage="average") assert len(d_dataset.cluster_names) == 5 @@ -95,7 +104,6 @@ def test_additional_clustering(): assert d_dataset.cluster_similarity is None assert np.min(d_dataset.cluster_distance) == 0 assert np.max(d_dataset.cluster_distance) == 1 - # assert [d_dataset.cluster_weights[i] for i in d_dataset.cluster_names] == [16, 36] def protein_fasta_data(algo): @@ -130,7 +138,6 @@ def protein_pdb_data(algo): ) -@pytest.fixture def molecule_data(): data = dict((k, v) for k, v in read_csv(Path("data") / "pipeline" / "drugs.tsv", "\t")) return DataSet( @@ -141,7 +148,6 @@ def molecule_data(): ) -@pytest.fixture def genome_fasta_data(): data = dict((k, v) for k, v in read_folder(Path("data") / "genomes", "fna")) return DataSet( @@ -164,16 +170,18 @@ def test_cdhit_protein(): @pytest.mark.todo @pytest.mark.nowin -def test_cdhit_genome(genome_fasta_data): +def test_cdhit_genome(): + data = genome_fasta_data() if platform.system() == "Windows": pytest.skip("CD-HIT is not supported on Windows") - run_cdhit(genome_fasta_data, 1, Path()) - check_clustering(genome_fasta_data) + run_cdhit(data, 1, Path()) + check_clustering(data) -def test_ecfp_molecule(molecule_data): - run_ecfp(molecule_data) - check_clustering(molecule_data) +def test_ecfp_molecule(): + data = molecule_data() + run_ecfp(data) + check_clustering(data) @pytest.mark.nowin @@ -220,6 +228,29 @@ def test_mmseqspp_protein(): check_clustering(data) +@pytest.mark.parametrize("algo", ["FP", "MD"]) +@pytest.mark.parametrize("in_type", ["Original", "List", "Numpy"]) +@pytest.mark.parametrize("method", ["AllBit", "Asymmetric", 
"BraunBlanquet", "Cosine", "Dice", "Kulczynski", + "McConnaughey", "OnBit", "RogotGoldberg", "Russel", "Sokal", "Tanimoto"]) +def test_tanimoto(algo, in_type, method, md_calculator): + if algo == "MD" and in_type == "Original": + pytest.skip("Molecular descriptors cannot directly be used as input.") + data = molecule_data() + if algo == "FP": + embed = lambda x: AllChem.GetMorganFingerprintAsBitVect(Chem.MolFromSmiles(x), 2, nBits=1024) + else: + embed = lambda x: md_calculator.CalcDescriptors(Chem.MolFromSmiles(x)) + if in_type == "Original": + wrap = lambda x: x + elif in_type == "List": + wrap = lambda x: [int(y) for y in x] + else: + wrap = lambda x: np.array(list(x)).astype(int) + data.data = dict((k, wrap(embed(v))) for k, v in data.data.items()) + run_tanimoto(data, method) + check_clustering(data) + + @pytest.mark.nowin @pytest.mark.todo def test_tmalign_protein(): diff --git a/tests/test_clustom_args.py b/tests/test_custom_args.py similarity index 100% rename from tests/test_clustom_args.py rename to tests/test_custom_args.py diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index 1e882b8..6e8fe4d 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -1,8 +1,15 @@ +import pickle import shutil from pathlib import Path +import h5py +import pandas as pd import pytest +from rdkit import Chem +from rdkit.Chem import AllChem, Descriptors +from rdkit.ML.Descriptors import MoleculeDescriptors +from datasail.reader.read_molecules import read_molecule_data from datasail.sail import sail, datasail from tests.utils import check_folder, run_sail @@ -219,6 +226,47 @@ def test_report_repeat(): assert len(list(c2.iterdir())) == 11 +@pytest.fixture() +def md_calculator(): + descriptor_names = [desc[0] for desc in Descriptors._descList] + return MoleculeDescriptors.MolecularDescriptorCalculator(descriptor_names) + + +@pytest.mark.parametrize("mode", ["CSV", "TSV", "PKL", "H5PY"]) +def test_input_formats(mode, md_calculator): + base = Path("data") / "pipeline" + drugs = pd.read_csv(base / "drugs.tsv", sep="\t") + ddict = {row["Drug_ID"]: row["SMILES"] for index, row in drugs.iterrows()} + (base / "input_forms").mkdir(exist_ok=True, parents=True) + + if mode == "CSV": + filepath = base / "input_forms" / "drugs.csv" + drugs.to_csv(filepath, sep=",", index=False) + elif mode == "TSV": + filepath = base / "input_forms" / "drugs.tsv" + drugs.to_csv(filepath, sep="\t", index=False) + elif mode == "PKL": + data = {} + for k, v in ddict.items(): + data[k] = AllChem.MolToSmiles(Chem.MolFromSmiles(v)) + filepath = base / "input_forms" / "drugs.pkl" + with open(filepath, "wb") as f: + pickle.dump(data, f) + elif mode == "H5PY": + filepath = base / "input_forms" / "drugs.h5" + with h5py.File(filepath, "w") as f: + for k, v in ddict.items(): + f[k] = list(md_calculator.CalcDescriptors(Chem.MolFromSmiles(v))) + else: + raise ValueError(f"Unknown mode: {mode}") + + dataset = read_molecule_data(filepath) + + shutil.rmtree(base / "input_forms", ignore_errors=True) + + assert set(dataset.names) == set(ddict.keys()) + + @pytest.mark.todo def test_genomes(): base = Path("data") / "genomes"