Extending handling of embeddings at input, plus more tests
Roman Joeres authored and committed on Mar 15, 2024
1 parent 0478d6c commit 4696362
Showing 12 changed files with 347 additions and 215 deletions.
2 changes: 1 addition & 1 deletion datasail/cluster/clustering.py
@@ -12,7 +12,7 @@
from datasail.cluster.mash import run_mash
from datasail.cluster.mmseqs2 import run_mmseqs
from datasail.cluster.mmseqspp import run_mmseqspp
from datasail.cluster.tanimoto import run_tanimoto
from datasail.cluster.vectors import run_tanimoto
from datasail.cluster.utils import heatmap
from datasail.cluster.wlk import run_wlk
from datasail.reader.utils import DataSet
15 changes: 5 additions & 10 deletions datasail/cluster/ecfp.py
@@ -1,20 +1,21 @@
import numpy as np
from rdkit import Chem, DataStructs, RDLogger
from rdkit import Chem, RDLogger
from rdkit.Chem import AllChem
from rdkit.Chem.Scaffolds.MurckoScaffold import MakeScaffoldGeneric
from rdkit.Chem.rdchem import MolSanitizeException

from datasail.cluster.utils import read_molecule_encoding
from datasail.cluster.vectors import run, SIM_OPTIONS
from datasail.reader.utils import DataSet
from datasail.settings import LOGGER


def run_ecfp(dataset: DataSet) -> None:
def run_ecfp(dataset: DataSet, method: SIM_OPTIONS = "Tanimoto") -> None:
"""
Compute 1024Bit-ECPFs for every molecule in the dataset and then compute pairwise Tanimoto-Scores of them.
Args:
dataset: The dataset to compute pairwise, elementwise similarities for
method: The similarity measure to use. Default is "Tanimoto".
"""
lg = RDLogger.logger()
lg.setLevel(RDLogger.CRITICAL)
@@ -54,14 +55,8 @@ def run_ecfp(dataset: DataSet) -> None:
        fps.append(AllChem.GetMorganFingerprintAsBitVect(Chem.MolFromSmiles(scaffold), 2, nBits=1024))

    LOGGER.info(f"Reduced {len(dataset.names)} molecules to {len(dataset.cluster_names)}")

    LOGGER.info("Compute Tanimoto Coefficients")

    count = len(dataset.cluster_names)
    dataset.cluster_similarity = np.zeros((count, count))
    for i in range(count):
        dataset.cluster_similarity[i, i] = 1
        dataset.cluster_similarity[i, :i] = DataStructs.BulkTanimotoSimilarity(fps[i], fps[:i])
        dataset.cluster_similarity[:i, i] = dataset.cluster_similarity[i, :i]
    run(dataset, fps, method)

    dataset.cluster_map = dict((name, Chem.MolToSmiles(scaffolds[name])) for name in dataset.names)
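With this change, run_ecfp only builds the fingerprints and delegates the pairwise scoring to vectors.run, so any measure from SIM_OPTIONS can be requested. A minimal call sketch (the construction of `dataset` is assumed here, not shown in this diff):

    from datasail.cluster.ecfp import run_ecfp

    # `dataset` is assumed to be a populated DataSet mapping names to SMILES strings.
    run_ecfp(dataset, method="Dice")  # omitting `method` keeps the "Tanimoto" default
    print(dataset.cluster_similarity.shape)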
29 changes: 0 additions & 29 deletions datasail/cluster/tanimoto.py

This file was deleted.

136 changes: 136 additions & 0 deletions datasail/cluster/vectors.py
@@ -0,0 +1,136 @@
import copy
from typing import Literal

import numpy as np
from rdkit import DataStructs

from datasail.reader.utils import DataSet
from datasail.settings import LOGGER

SIM_OPTIONS = Literal[
    "AllBit", "Asymmetric", "BraunBlanquet", "Cosine", "Dice", "Kulczynski", "McConnaughey", "OnBit", "RogotGoldberg",
    "Russel", "Sokal", "Tanimoto", "Jaccard"
]


def get_rdkit_fct(method: SIM_OPTIONS):
    """
    Get the RDKit function for the given similarity measure.
    Args:
        method: The name of the similarity measure to get the function for.
    Returns:
        The RDKit function for the given similarity measure.
    """
    if method == "AllBit":
        return DataStructs.BulkAllBitSimilarity
    if method == "Asymmetric":
        return DataStructs.BulkAsymmetricSimilarity
    if method == "BraunBlanquet":
        return DataStructs.BulkBraunBlanquetSimilarity
    if method == "Cosine":
        return DataStructs.BulkCosineSimilarity
    if method == "Dice":
        return DataStructs.BulkDiceSimilarity
    if method == "Kulczynski":
        return DataStructs.BulkKulczynskiSimilarity
    if method == "McConnaughey":
        return DataStructs.BulkMcConnaugheySimilarity
    if method == "OnBit":
        return DataStructs.BulkOnBitSimilarity
    if method == "RogotGoldberg":
        return DataStructs.BulkRogotGoldbergSimilarity
    if method == "Russel":
        return DataStructs.BulkRusselSimilarity
    if method == "Sokal":
        return DataStructs.BulkSokalSimilarity
    if method == "Tanimoto" or method == "Jaccard":
        return DataStructs.BulkTanimotoSimilarity
    if method == "Tversky":
        return DataStructs.BulkTverskySimilarity
    raise ValueError(f"Unknown method {method}")


def iterable2intvect(it):
    """
    Convert an iterable to an RDKit LongSparseIntVect.
    Args:
        it: The iterable to convert.
    Returns:
        The RDKit LongSparseIntVect.
    """
    output = DataStructs.LongSparseIntVect(len(it))
    for i, v in enumerate(it):
        output[i] = max(-2_147_483_648, min(2_147_483_647, int(v)))
    return output


def iterable2bitvect(it):
    """
    Convert an iterable to an RDKit ExplicitBitVect.
    Args:
        it: The iterable to convert.
    Returns:
        The RDKit ExplicitBitVect.
    """
    output = DataStructs.ExplicitBitVect(len(it))
    output.SetBitsFromList([i for i, v in enumerate(it) if v])
    return output


def run_tanimoto(dataset: DataSet, method: SIM_OPTIONS = "Tanimoto") -> None:
    """
    Compute pairwise Tanimoto-Scores of the given dataset.
    Args:
        dataset: The dataset to compute pairwise, elementwise similarities for
        method: The similarity measure to use. Default is "Tanimoto".
    """
    LOGGER.info("Start Tanimoto clustering")

    embed = dataset.data[dataset.names[0]]
    if isinstance(embed, (list, tuple, np.ndarray)):
        if isinstance(embed[0], int) or np.issubdtype(embed[0].dtype, int):
            if method in ["AllBit", "Asymmetric", "BraunBlanquet", "Cosine", "Kulczynski", "McConnaughey", "OnBit",
                          "RogotGoldberg", "Russel", "Sokal"]:
                dataset.data = {k: iterable2bitvect(v) for k, v in dataset.data.items()}
            else:
                dataset.data = {k: iterable2intvect(v) for k, v in dataset.data.items()}
            embed = dataset.data[dataset.names[0]]
        else:
            raise ValueError("Embeddings with non-integer elements are not supported at the moment.")
    if not isinstance(embed,
                      (DataStructs.ExplicitBitVect, DataStructs.LongSparseIntVect, DataStructs.IntSparseIntVect)):
        raise ValueError(f"Unsupported embedding type {type(embed)}. Please use either RDKit datastructures, lists, "
                         f"tuples or one-dimensional numpy arrays.")
    fps = [dataset.data[name] for name in dataset.names]
    run(dataset, fps, method)

    dataset.cluster_names = copy.deepcopy(dataset.names)
    dataset.cluster_map = dict((n, n) for n in dataset.names)


def run(dataset, fps, method):
    """
    Compute pairwise similarities of the given fingerprints.
    Args:
        dataset: The dataset to compute pairwise similarities for.
        fps: The fingerprints to compute pairwise similarities for.
        method: The similarity measure to use.
    """
    fct = get_rdkit_fct(method)
    dataset.cluster_similarity = np.zeros((len(dataset.data), len(dataset.data)))
    for i in range(len(dataset.data)):
        dataset.cluster_similarity[i, i] = 1
        dataset.cluster_similarity[i, :i] = fct(fps[i], fps[:i])
        dataset.cluster_similarity[:i, i] = dataset.cluster_similarity[i, :i]

    min_val = np.min(dataset.cluster_similarity)
    max_val = np.max(dataset.cluster_similarity)
    dataset.cluster_similarity = (dataset.cluster_similarity - min_val) / (max_val - min_val)
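Taken together, run_tanimoto now also accepts raw integer embeddings, converting them to RDKit vectors before scoring. A small end-to-end sketch (names and vectors are invented; it assumes DataSet can be instantiated bare and populated field by field):

    import numpy as np

    from datasail.cluster.vectors import run_tanimoto
    from datasail.reader.utils import DataSet

    dataset = DataSet()
    dataset.names = ["a", "b", "c"]
    dataset.data = {
        "a": np.array([1, 0, 1, 1]),  # integer vectors become LongSparseIntVect here,
        "b": np.array([1, 1, 0, 1]),  # since "Tanimoto" is not in the bit-vector-only list
        "c": np.array([0, 0, 1, 0]),
    }
    run_tanimoto(dataset, method="Tanimoto")
    print(dataset.cluster_similarity)  # 3x3 matrix, min-max scaled to [0, 1]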
49 changes: 9 additions & 40 deletions datasail/reader/read_genomes.py
@@ -1,13 +1,8 @@
import pickle
from pathlib import Path
from typing import List, Tuple, Optional, Generator, Callable, Iterable

import h5py
from typing import List, Tuple, Optional

from datasail.reader.read_molecules import remove_duplicate_values
from datasail.reader.read_proteins import parse_fasta
from datasail.reader.utils import DataSet, read_data, DATA_INPUT, MATRIX_INPUT, read_folder, read_csv
from datasail.settings import G_TYPE, UNK_LOCATION, FORM_FASTA, FASTA_FORMATS, FORM_GENOMES
from datasail.reader.utils import DataSet, read_data, DATA_INPUT, MATRIX_INPUT, read_folder, read_data_input
from datasail.settings import G_TYPE, UNK_LOCATION, FORM_FASTA, FORM_GENOMES


def read_genome_data(
@@ -39,38 +34,12 @@ def read_genome_data(
        A dataset storing all information on that datatype
    """
    dataset = DataSet(type=G_TYPE, location=UNK_LOCATION, format=FORM_FASTA)
    if isinstance(data, Path):
        if data.is_file():
            if data.suffix[1:].lower() in FASTA_FORMATS:
                dataset.data = parse_fasta(data)
            elif data.suffix[1:].lower() == "tsv":
                dataset.data = dict(read_csv(data, sep="\t"))
            elif data.suffix[1:].lower() == "csv":
                dataset.data = dict(read_csv(data, sep=","))
            elif data.suffix[1:].lower() == "pkl":
                with open(data, "rb") as file:
                    dataset.data = dict(pickle.load(file))
            elif data.suffix[1:].lower() == "h5":
                with h5py.File(data) as file:
                    dataset.data = dict(file[k] for k in file.keys())
            else:
                raise ValueError()
        elif data.is_dir():
            dataset.data = dict(read_folder(data))
            dataset.format = FORM_GENOMES
        else:
            raise ValueError()
        dataset.location = data
    elif (isinstance(data, list) or isinstance(data, tuple)) and isinstance(data[0], Iterable) and len(data[0]) == 2:
        dataset.data = dict(data)
    elif isinstance(data, dict):
        dataset.data = data
    elif isinstance(data, Callable):
        dataset.data = data()
    elif isinstance(data, Generator):
        dataset.data = dict(data)
    else:
        raise ValueError()

    def read_dir(ds):
        ds.data = dict(read_folder(data))
        ds.format = FORM_GENOMES

    read_data_input(data, dataset, read_dir)

    dataset = read_data(weights, strats, sim, dist, inter, index, num_clusters, tool_args, dataset)
    dataset = remove_duplicate_values(dataset, dataset.data)
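The per-format branching removed above is consolidated into the new read_data_input helper in datasail/reader/utils.py, whose body is not loaded on this page. A sketch of what it presumably does, reconstructed from the deleted branches (the fasta- and h5-specific cases are elided):

    import pickle
    from pathlib import Path
    from typing import Callable, Generator

    from datasail.reader.utils import read_csv

    def read_data_input(data, dataset, read_dir):
        # Reconstruction sketch only -- the real helper lives in datasail/reader/utils.py.
        if isinstance(data, Path):
            if data.is_file():
                if data.suffix[1:].lower() == "tsv":
                    dataset.data = dict(read_csv(data, sep="\t"))
                elif data.suffix[1:].lower() == "csv":
                    dataset.data = dict(read_csv(data, sep=","))
                elif data.suffix[1:].lower() == "pkl":
                    with open(data, "rb") as file:
                        dataset.data = dict(pickle.load(file))
                else:  # fasta/h5 handling elided from this sketch
                    raise ValueError()
            elif data.is_dir():
                read_dir(dataset)  # caller-supplied, format-specific directory reader
            else:
                raise ValueError()
            dataset.location = data
        elif isinstance(data, (list, tuple)) and len(data[0]) == 2:
            dataset.data = dict(data)  # iterable of (name, value) pairs
        elif isinstance(data, dict):
            dataset.data = data
        elif isinstance(data, Callable):
            dataset.data = data()
        elif isinstance(data, Generator):
            dataset.data = dict(data)
        else:
            raise ValueError()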
57 changes: 17 additions & 40 deletions datasail/reader/read_molecules.py
@@ -1,8 +1,5 @@
import pickle
from pathlib import Path
from typing import List, Tuple, Optional, Callable, Generator, Iterable
from typing import List, Tuple, Optional

import h5py
import numpy as np
from rdkit import Chem
from rdkit.Chem import MolFromMol2File, MolFromMolFile, MolFromPDBFile, MolFromTPLFile, MolFromXYZFile
@@ -11,7 +8,7 @@
except ImportError:
    MolFromMrvFile = None

from datasail.reader.utils import read_csv, DataSet, read_data, DATA_INPUT, MATRIX_INPUT
from datasail.reader.utils import DataSet, read_data, DATA_INPUT, MATRIX_INPUT, read_data_input
from datasail.settings import M_TYPE, UNK_LOCATION, FORM_SMILES


@@ -55,42 +52,18 @@ def read_molecule_data(
        A dataset storing all information on that datatype
    """
    dataset = DataSet(type=M_TYPE, format=FORM_SMILES, location=UNK_LOCATION)
    if isinstance(data, Path):
        if data.is_file():
            if data.suffix[1:].lower() == "tsv":
                dataset.data = dict(read_csv(data, sep="\t"))
            elif data.suffix[1:].lower() == "csv":
                dataset.data = dict(read_csv(data, sep=","))
            elif data.suffix[1:].lower() == "pkl":
                with open(data, "rb") as file:
                    dataset.data = dict(pickle.load(file))
            elif data.suffix[1:].lower() == "h5":
                with h5py.File(data) as file:
                    dataset.data = dict(file[k] for k in file.keys())
            else:
                raise ValueError()
        elif data.is_dir():
            dataset.data = {}
            for file in data.iterdir():
                if file.suffix[1:].lower() != "sdf" and mol_reader[file.suffix[1:].lower()] is not None:
                    dataset.data[file.stem] = mol_reader[file.suffix[1:].lower()](file)
                else:
                    suppl = Chem.SDMolSupplier(file)
                    for i, mol in enumerate(suppl):
                        dataset.data[f"{file.stem}_{i}"] = mol
        else:
            raise ValueError()
        dataset.location = data
    elif (isinstance(data, list) or isinstance(data, tuple)) and isinstance(data[0], Iterable) and len(data[0]) == 2:
        dataset.data = dict(data)
    elif isinstance(data, dict):
        dataset.data = data
    elif isinstance(data, Callable):
        dataset.data = data()
    elif isinstance(data, Generator):
        dataset.data = dict(data)
    else:
        raise ValueError()

    def read_dir(ds: DataSet):
        ds.data = {}
        for file in data.iterdir():
            if file.suffix[1:].lower() != "sdf" and mol_reader[file.suffix[1:].lower()] is not None:
                ds.data[file.stem] = mol_reader[file.suffix[1:].lower()](file)
            else:
                suppl = Chem.SDMolSupplier(file)
                for i, mol in enumerate(suppl):
                    ds.data[f"{file.stem}_{i}"] = mol

    read_data_input(data, dataset, read_dir)

    dataset = read_data(weights, strats, sim, dist, inter, index, num_clusters, tool_args, dataset)
    dataset = remove_molecule_duplicates(dataset)
@@ -109,6 +82,10 @@ def remove_molecule_duplicates(dataset: DataSet) -> DataSet:
    Returns:
        Update arguments, as the location of the data might change and an ID-Map file might be added.
    """
    if isinstance(dataset.data[dataset.names[0]], (list, tuple, np.ndarray)):
        # TODO: proper check for duplicate embeddings
        dataset.id_map = {n: n for n in dataset.names}
        return dataset

    # Extract invalid molecules
    non_mols = []
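The new guard at the top of remove_molecule_duplicates means embedding inputs skip molecule-level deduplication entirely and receive an identity ID map, as the TODO notes. A toy illustration of the behaviour (hypothetical names):

    import numpy as np

    # When values are lists/tuples/ndarrays, the function returns early:
    data = {"m1": np.array([1, 0, 3]), "m2": np.array([1, 0, 3])}  # duplicates are kept
    id_map = {n: n for n in data}  # what the early return produces
    assert id_map == {"m1": "m1", "m2": "m2"}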
(Diffs for the remaining six changed files were not loaded on this page.)
