Skip to content

Commit

Permalink
Minor update and extension of embedding clustering with scipy
Browse files Browse the repository at this point in the history
  • Loading branch information
Old-Shatterhand committed Mar 20, 2024
1 parent 349007b commit d8b5cee
Show file tree
Hide file tree
Showing 5 changed files with 153 additions and 55 deletions.
5 changes: 5 additions & 0 deletions CONTRIBUTING.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
Contributing to DataSAIL
========================

As with every other open-source project, you can contribute to DataSAIL. For more information, we refer you to the
[contribution guidelines in our documentation](https://datasail.readthedocs.io/) <!-- NOTE(review): link target was empty; presumably the readthedocs contribution page — verify URL -->.
4 changes: 2 additions & 2 deletions datasail/cluster/clustering.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from datasail.cluster.mash import run_mash
from datasail.cluster.mmseqs2 import run_mmseqs
from datasail.cluster.mmseqspp import run_mmseqspp
from datasail.cluster.vectors import run_tanimoto
from datasail.cluster.vectors import run_vector
from datasail.cluster.utils import heatmap
from datasail.cluster.wlk import run_wlk
from datasail.reader.utils import DataSet
Expand Down Expand Up @@ -111,7 +111,7 @@ def similarity_clustering(dataset: DataSet, threads: int = 1, log_dir: Optional[
elif dataset.similarity.lower() == MMSEQSPP:
run_mmseqspp(dataset, threads, log_dir)
elif dataset.similarity.lower() == TANIMOTO:
run_tanimoto(dataset)
run_vector(dataset)
else:
raise ValueError(f"Unknown cluster method: {dataset.similarity}")

Expand Down
2 changes: 1 addition & 1 deletion datasail/cluster/diamond.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def run_diamond(dataset: DataSet, threads: int, log_dir: Optional[Path] = None)
cmd = lambda x: f"mkdir {result_folder} && " \
f"cd {result_folder} && " \
f"diamond makedb --in ../diamond.fasta --db seqs.dmnd {makedb_args} {x} --threads {threads} && " \
f"diamond blastp --db seqs.dmnd --query {str(Path('..') / dataset.get_location_path())} --out alis.tsv --outfmt 6 qseqid sseqid pident " \
f"diamond blastp --db seqs.dmnd --query ../diamond.fasta --out alis.tsv --outfmt 6 qseqid sseqid pident " \
f"--threads {threads} {blastp_args} {x} && " \
f"rm ../diamond.fasta"

Expand Down
147 changes: 103 additions & 44 deletions datasail/cluster/vectors.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,24 @@
import copy
from typing import Literal
from typing import Literal, get_args, Union

import numpy as np
import scipy
from rdkit import DataStructs

from datasail.reader.utils import DataSet
from datasail.settings import LOGGER

SIM_OPTIONS = Literal[
"AllBit", "Asymmetric", "BraunBlanquet", "Cosine", "Dice", "Kulczynski", "McConnaughey", "OnBit", "RogotGoldberg",
"Russel", "Sokal", "Tanimoto", "Jaccard"
"allbit", "asymmetric", "braunblanquet", "cosine", "dice", "kulczynski", "mcconnaughey", "onbit", "rogotgoldberg",
"russel", "sokal"
]

# produces inf or nan: correlation, cosine, jensenshannon, seuclidean, braycurtis
# boolean only: dice, kulczynski1, rogerstanimoto, russelrao, sokalmichener, sokalsneath, yule
# matching == hamming, manhattan == cityblock (unofficial aliases)
DIST_OPTIONS = Literal[
"canberra", "chebyshev", "cityblock", "euclidean", "hamming", "jaccard", "mahalanobis", "manhattan", "matching",
"minkowski", "sqeuclidean", "tanimoto"
]


Expand All @@ -23,35 +32,51 @@ def get_rdkit_fct(method: SIM_OPTIONS):
Returns:
The RDKit function for the given similarity measure.
"""
if method == "AllBit":
if method == "allbit":
return DataStructs.BulkAllBitSimilarity
if method == "Asymmetric":
if method == "asymmetric":
return DataStructs.BulkAsymmetricSimilarity
if method == "BraunBlanquet":
if method == "braunblanquet":
return DataStructs.BulkBraunBlanquetSimilarity
if method == "Cosine":
if method == "cosine":
return DataStructs.BulkCosineSimilarity
if method == "Dice":
if method == "dice":
return DataStructs.BulkDiceSimilarity
if method == "Kulczynski":
if method == "kulczynski":
return DataStructs.BulkKulczynskiSimilarity
if method == "McConnaughey":
if method == "mcconnaughey":
return DataStructs.BulkMcConnaugheySimilarity
if method == "OnBit":
if method == "onbit":
return DataStructs.BulkOnBitSimilarity
if method == "RogotGoldberg":
if method == "rogotgoldberg":
return DataStructs.BulkRogotGoldbergSimilarity
if method == "Russel":
if method == "russel":
return DataStructs.BulkRusselSimilarity
if method == "Sokal":
if method == "sokal":
return DataStructs.BulkSokalSimilarity
if method == "Tanimoto" or method == "Jaccard":
return DataStructs.BulkTanimotoSimilarity
if method == "Tversky":
return DataStructs.BulkTverskySimilarity
raise ValueError(f"Unknown method {method}")


def rdkit_sim(fps, method: SIM_OPTIONS) -> np.ndarray:
    """
    Compute the pairwise similarity matrix for a list of RDKit fingerprint vectors.

    Args:
        fps: List of RDKit vectors to compute the pairwise similarity matrix for
        method: Name of the similarity measure to use; resolved to an RDKit bulk
            similarity function via get_rdkit_fct

    Returns:
        A symmetric (len(fps), len(fps)) numpy array of pairwise similarities
        with 1 on the diagonal
    """
    fct = get_rdkit_fct(method)
    matrix = np.zeros((len(fps), len(fps)))
    for i in range(len(fps)):
        matrix[i, i] = 1
        # compute row i against all earlier fingerprints in one bulk call,
        # then mirror into the upper triangle to keep the matrix symmetric
        matrix[i, :i] = fct(fps[i], fps[:i])
        matrix[:i, i] = matrix[i, :i]
    return matrix


def iterable2intvect(it):
"""
Convert an iterable to an RDKit LongSparseIntVect.
Expand Down Expand Up @@ -83,39 +108,71 @@ def iterable2bitvect(it):
return output


def run_tanimoto(dataset: DataSet, method: SIM_OPTIONS = "Tanimoto") -> None:
def run_vector(dataset: DataSet, method: SIM_OPTIONS = "tanimoto") -> None:
"""
Compute pairwise Tanimoto-Scores of the given dataset.
Args:
dataset: The dataset to compute pairwise, elementwise similarities for
method: The similarity measure to use. Default is "Tanimoto".
method: The similarity measure to use. Default is "tanimoto".
"""
LOGGER.info("Start Tanimoto clustering")
method = method.lower()

embed = dataset.data[dataset.names[0]]
if isinstance(embed, (list, tuple, np.ndarray)):
if isinstance(embed[0], int) or np.issubdtype(embed[0].dtype, int):
if method in ["AllBit", "Asymmetric", "BraunBlanquet", "Cosine", "Kulczynski", "McConnaughey", "OnBit",
"RogotGoldberg", "Russel", "Sokal"]:
dataset.data = {k: iterable2bitvect(v) for k, v in dataset.data.items()}
if method in get_args(SIM_OPTIONS):
if isinstance(embed, (list, tuple, np.ndarray)):
if isinstance(embed[0], int) or np.issubdtype(embed[0].dtype, int):
if method in ["allbit", "asymmetric", "braunblanquet", "cosine", "kulczynski", "mcconnaughey", "onbit",
"rogotgoldberg", "russel", "sokal"]:
dataset.data = {k: iterable2bitvect(v) for k, v in dataset.data.items()}
else:
dataset.data = {k: iterable2intvect(v) for k, v in dataset.data.items()}
embed = dataset.data[dataset.names[0]]
else:
dataset.data = {k: iterable2intvect(v) for k, v in dataset.data.items()}
embed = dataset.data[dataset.names[0]]
else:
raise ValueError("Embeddings with non-integer elements are not supported at the moment.")
if not isinstance(embed,
(DataStructs.ExplicitBitVect, DataStructs.LongSparseIntVect, DataStructs.IntSparseIntVect)):
raise ValueError(f"Unsupported embedding type {type(embed)}. Please use either RDKit datastructures, lists, "
f"tuples or one-dimensional numpy arrays.")
raise ValueError(f"Embeddings with non-integer elements are not supported for {method}.")
if not isinstance(embed, (
DataStructs.ExplicitBitVect, DataStructs.LongSparseIntVect, DataStructs.IntSparseIntVect
)):
raise ValueError(
f"Unsupported embedding type {type(embed)}. Please use either RDKit datastructures, lists, "
f"tuples or one-dimensional numpy arrays.")
elif method in get_args(DIST_OPTIONS):
if isinstance(embed, (list, tuple, DataStructs.ExplicitBitVect, DataStructs.LongSparseIntVect, DataStructs.IntSparseIntVect)):
dataset.data = {k: np.array(list(v), dtype=np.float64) for k, v in dataset.data.items()}
if not isinstance(dataset.data[dataset.names[0]], np.ndarray):
raise ValueError(
f"Unsupported embedding type {type(embed)}. Please use either RDKit datastructures, lists, "
f"tuples or one-dimensional numpy arrays.")
else:
raise ValueError(f"Unknown method {method}")
fps = [dataset.data[name] for name in dataset.names]
run(dataset, fps, method)

dataset.cluster_names = copy.deepcopy(dataset.names)
dataset.cluster_map = dict((n, n) for n in dataset.names)


def run(dataset, fps, method):
def scale_min_max(matrix: np.ndarray) -> np.ndarray:
    """
    Transform values by scaling them to the 0-1 range (min-max normalization).

    Args:
        matrix: The numpy array to be scaled

    Returns:
        The scaled numpy array. If all entries are equal (zero spread), an
        all-zero array of the same shape is returned instead of dividing by
        zero, which would fill the result with NaNs.
    """
    min_val, max_val = np.min(matrix), np.max(matrix)
    spread = max_val - min_val
    if spread == 0:
        # Constant matrix: (matrix - min) / 0 would produce NaNs; return zeros instead.
        return np.zeros_like(matrix, dtype=np.float64)
    return (matrix - min_val) / spread


def run(
dataset: DataSet,
fps: Union[np.ndarray, DataStructs.ExplicitBitVect, DataStructs.LongSparseIntVect,
DataStructs.IntSparseIntVect],
method: Union[SIM_OPTIONS, DIST_OPTIONS],
) -> None:
"""
Compute pairwise similarities of the given fingerprints.
Expand All @@ -124,13 +181,15 @@ def run(dataset, fps, method):
fps: The fingerprints to compute pairwise similarities for.
method: The similarity measure to use.
"""
fct = get_rdkit_fct(method)
dataset.cluster_similarity = np.zeros((len(fps), len(fps)))
for i in range(len(fps)):
dataset.cluster_similarity[i, i] = 1
dataset.cluster_similarity[i, :i] = fct(fps[i], fps[:i])
dataset.cluster_similarity[:i, i] = dataset.cluster_similarity[i, :i]

min_val = np.min(dataset.cluster_similarity)
max_val = np.max(dataset.cluster_similarity)
dataset.cluster_similarity = (dataset.cluster_similarity - min_val) / (max_val - min_val)
if method in get_args(SIM_OPTIONS):
dataset.cluster_similarity = scale_min_max(rdkit_sim(fps, method))
elif method in get_args(DIST_OPTIONS):
if method == "mahalanobis" and len(fps) <= len(fps[0]):
raise ValueError(
f"For clustering with the Mahalanobis method, you have to have more observations that dimensions in "
f"the embeddings. The number of samples ({len(fps)}) is too small; the covariance matrix is singular. "
f"For observations with {len(fps[0])} dimensions, at least {len(fps[0]) + 1} observations are required."
)
dataset.cluster_distance = scale_min_max(scipy.spatial.distance.cdist(
fps, fps, metric={"manhattan": "cityblock", "tanimoto": "jaccard"}.get(method, method)
))
50 changes: 42 additions & 8 deletions tests/test_clustering.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import platform
from pathlib import Path
from typing import get_args

import numpy as np
import pandas as pd
Expand All @@ -16,7 +17,7 @@
from datasail.cluster.mash import run_mash
from datasail.cluster.mmseqs2 import run_mmseqs
from datasail.cluster.mmseqspp import run_mmseqspp
from datasail.cluster.vectors import run_tanimoto
from datasail.cluster.vectors import run_vector, SIM_OPTIONS
from datasail.cluster.tmalign import run_tmalign
from datasail.cluster.wlk import run_wlk

Expand Down Expand Up @@ -254,11 +255,15 @@ def test_mmseqspp_protein():

@pytest.mark.parametrize("algo", ["FP", "MD"])
@pytest.mark.parametrize("in_type", ["Original", "List", "Numpy"])
@pytest.mark.parametrize("method", ["AllBit", "Asymmetric", "BraunBlanquet", "Cosine", "Dice", "Kulczynski",
"McConnaughey", "OnBit", "RogotGoldberg", "Russel", "Sokal", "Tanimoto"])
def test_tanimoto(algo, in_type, method, md_calculator):
if algo == "MD" and in_type == "Original":
pytest.skip("Molecular descriptors cannot directly be used as input.")
# @pytest.mark.parametrize("method", ["AllBit", "Asymmetric", "BraunBlanquet", "Cosine", "Dice", "Kulczynski",
# "McConnaughey", "OnBit", "RogotGoldberg", "Russel", "Sokal", "Tanimoto"])
@pytest.mark.parametrize("method", [
"allbit", "asymmetric", "braunblanquet", "cosine", "dice", "kulczynski", "mcconnaughey", "onbit", "rogotgoldberg",
"russel", "sokal",
"canberra", "chebyshev", "cityblock", "euclidean", "hamming", "jaccard",
"mahalanobis", "manhattan", "matching", "minkowski", "seuclidean", "sqeuclidean", "tanimoto"
])
def test_vector(algo, in_type, method, md_calculator):
data = molecule_data()
if algo == "FP":
embed = lambda x: AllChem.GetMorganFingerprintAsBitVect(Chem.MolFromSmiles(x), 2, nBits=1024)
Expand All @@ -271,8 +276,37 @@ def test_tanimoto(algo, in_type, method, md_calculator):
else:
wrap = lambda x: np.array(list(x)).astype(int)
data.data = dict((k, wrap(embed(v))) for k, v in data.data.items())
run_tanimoto(data, method)
check_clustering(data)
if (algo == "MD" and in_type == "Original" and method in get_args(SIM_OPTIONS)) or method == "mahalanobis":
with pytest.raises(ValueError):
run_vector(data, method)
else:
run_vector(data, method)
check_clustering(data)


@pytest.mark.parametrize("method", [
    "allbit", "asymmetric", "braunblanquet", "cosine", "dice", "kulczynski", "mcconnaughey", "onbit", "rogotgoldberg",
    "russel", "sokal",
    "canberra", "chebyshev", "cityblock", "euclidean", "hamming", "jaccard",
    "mahalanobis", "manhattan", "matching", "minkowski", "sqeuclidean", "tanimoto"
])
def test_vector_edge(method):
    """
    Edge-case test: run every supported similarity and distance method on the
    complete set of length-3 binary embeddings, including the all-zero vector
    "H" (presumably included to stress measures whose self-similarity
    denominator can vanish on empty vectors — confirm against the measures'
    definitions).
    """
    # All 2**3 possible binary embeddings of length 3, one per name A..H.
    dataset = DataSet(
        names=["A", "B", "C", "D", "E", "F", "G", "H"],
        data={
            "A": np.array([1, 1, 1]),
            "B": np.array([1, 1, 0]),
            "C": np.array([1, 0, 1]),
            "D": np.array([0, 1, 1]),
            "E": np.array([1, 0, 0]),
            "F": np.array([0, 1, 0]),
            "G": np.array([0, 0, 1]),
            "H": np.array([0, 0, 0]),
        },
    )
    run_vector(dataset, method)
    # check_clustering verifies the cluster names/map and matrices are well-formed.
    check_clustering(dataset)



@pytest.mark.nowin
Expand Down

0 comments on commit d8b5cee

Please sign in to comment.