From e8744d280eb08d7b91f6dc56336e6fc3eb1e5739 Mon Sep 17 00:00:00 2001 From: "Christian W. Feldmann" Date: Fri, 22 Sep 2023 11:06:23 +0200 Subject: [PATCH 01/41] Add InvalidInstance and add empty file for wrappers --- scikit_mol/utilities.py | 7 +++++++ scikit_mol/wrapper.py | 0 2 files changed, 7 insertions(+) create mode 100644 scikit_mol/wrapper.py diff --git a/scikit_mol/utilities.py b/scikit_mol/utilities.py index 70eac51..24356d4 100644 --- a/scikit_mol/utilities.py +++ b/scikit_mol/utilities.py @@ -1,7 +1,14 @@ #For a non-scikit-learn check smiles sanitizer class +from typing import NamedTuple import pandas as pd from rdkit import Chem + +class InvalidInstance(NamedTuple): + pipeline_step: str + error: str + + class CheckSmilesSanitazion: def __init__(self, return_mol=False): self.return_mol = return_mol diff --git a/scikit_mol/wrapper.py b/scikit_mol/wrapper.py new file mode 100644 index 0000000..e69de29 From 361284bde859a73eed7ab6548922a79e84834d04 Mon Sep 17 00:00:00 2001 From: "Christian W. Feldmann" Date: Fri, 22 Sep 2023 16:15:45 +0200 Subject: [PATCH 02/41] Nothing works but I want to save commit --- scikit_mol/_invalid.py | 118 +++++++++++++++++++++++++++++++++ scikit_mol/conversions.py | 5 +- scikit_mol/fingerprints.py | 15 +++-- scikit_mol/utilities.py | 7 +- scikit_mol/wrapper.py | 7 ++ tests/fixtures.py | 2 +- tests/test_invalid_handling.py | 25 +++++++ 7 files changed, 164 insertions(+), 15 deletions(-) create mode 100644 scikit_mol/_invalid.py create mode 100644 tests/test_invalid_handling.py diff --git a/scikit_mol/_invalid.py b/scikit_mol/_invalid.py new file mode 100644 index 0000000..900dc71 --- /dev/null +++ b/scikit_mol/_invalid.py @@ -0,0 +1,118 @@ +from abc import ABC +from typing import Any, Callable, NamedTuple, Sequence, TypeVar + +import numpy as np +import numpy.typing as npt + +_T = TypeVar("_T") +_U = TypeVar("_U") + + +class InvalidInstance(NamedTuple): + pipeline_step: str + error: str + + +class ArrayWithInvalidInstances: + invalid_list: list[InvalidInstance] + + def __init__(self, array_list: list[npt.NDArray[np.int8] | InvalidInstance]): + self.is_valid_array = get_is_valid_array(array_list) + valid_vector_list = filter_by_list(array_list, self.is_valid_array) + self.matrix = np.vstack(valid_vector_list) + self.invalid_list = filter_by_list(array_list, ~self.is_valid_array) + + def __getitem__(self, item: int) -> npt.NDArray[np.int8] | InvalidInstance: + n_invalids_prior = sum(~self.is_valid_array[:item - 1]) + if self.is_valid_array[item]: + return self.matrix[item - n_invalids_prior] + return self.invalid_list[n_invalids_prior + 1] + + def __setitem__(self, key: int, value: npt.NDArray[np.int8] | InvalidInstance) -> None: + n_invalids_prior = sum(~self.is_valid_array[:key - 1]) + if isinstance(value, InvalidInstance): + if self.is_valid_array[key]: + self.matrix = np.delete(self.matrix, key - n_invalids_prior) + self.is_valid_array[key] = False + self.invalid_list.insert(n_invalids_prior + 1, value) + else: + self.invalid_list[n_invalids_prior + 1] = value + else: + if self.is_valid_array[key]: + self.matrix[key - n_invalids_prior] = value + else: + self.matrix = np.insert(self.matrix, key-n_invalids_prior, value) + del(self.invalid_list[n_invalids_prior + 1]) + self.is_valid_array[key] = True + + +def update_list_by( + old_list: list[npt.NDArray[np.int8] | InvalidInstance] | ArrayWithInvalidInstances, + new_values: list[Any], + value_indices: npt.NDArray[np.int_], + ): + old_list = list(old_list) + for new_value, idx in zip(new_values, value_indices, strict=True): + old_list[idx] = new_value + return old_list + + +def filter_by_list(item_list, is_valid_array: npt.NDArray[np.bool_]): + if isinstance(item_list, np.ndarray): + return item_list[is_valid_array] + + item_list_new = [] + for item, is_valid in zip(item_list, is_valid_array): + if is_valid: + item_list_new.append(item) + return item_list_new + +# Callable[[Sequence[Any], Sequence[Any], dict[str, Any]], Sequence[Any]] +# ) -> Callable[[Sequence[Any], Sequence[Any], dict[str, Any]], npt.NDArray[Any]] +def rdkit_error_handling(func): + def wrapper(obj, *args, **kwargs): + x = args[0] + if isinstance(x, ArrayWithInvalidInstances): + is_valid_array = x.is_valid_array + x_sub = x.matrix + else: + is_valid_array = get_is_valid_array(x) + x_sub = filter_by_list(x, is_valid_array) + if len(args) > 1: + y = args[1] + y_sub = filter_by_list(y, is_valid_array) + else: + y_sub = None + x_new = func(obj, x_sub, y_sub, **kwargs) + new_pos = np.where(is_valid_array)[0] + if isinstance(x, (list, ArrayWithInvalidInstances)): + x_list = update_list_by(x, x_new, new_pos) + else: + x_array = np.array(x) + x_array[is_valid_array] = x_new + x_list = list(x_array) + if isinstance(x_new, ArrayWithInvalidInstances): + return ArrayWithInvalidInstances(x_list) + return x_list + return wrapper + + +def filter_rows( + X: Sequence[_T], y: Sequence[_U] +) -> tuple[Sequence[_T], Sequence[_U]]: + is_valid_array = get_is_valid_array(X) + x_new = filter_by_list(X, is_valid_array) + y_new = filter_by_list(y, is_valid_array) + return x_new, y_new + + +def get_is_valid_array(item_list: Sequence[Any]) -> npt.NDArray[np.bool_]: + is_valid_list = [] + for i, item in enumerate(item_list): + if not isinstance(item, InvalidInstance): + is_valid_list.append(True) + else: + is_valid_list.append(False) + return np.array(is_valid_list, dtype=bool) + + diff --git a/scikit_mol/conversions.py b/scikit_mol/conversions.py index c9d668a..914da37 100644 --- a/scikit_mol/conversions.py +++ b/scikit_mol/conversions.py @@ -6,6 +6,8 @@ import numpy as np from sklearn.base import BaseEstimator, TransformerMixin +from scikit_mol._invalid import InvalidInstance + class SmilesToMolTransformer(BaseEstimator, TransformerMixin): @@ -56,8 +58,7 @@ def _transform(self, X_smiles_list): if mol: X_out.append(mol) else: - raise ValueError(f'Issue with parsing SMILES {smiles}\nYou probably should use the scikit-mol.sanitizer.Sanitizer on your dataset first') - + X_out.append(InvalidInstance(str(self), "Invalid Smiles.")) return X_out def inverse_transform(self, X_mols_list, y=None): #TODO, maybe the inverse transform should be configurable e.g. isomericSmiles etc.? diff --git a/scikit_mol/fingerprints.py b/scikit_mol/fingerprints.py index 81bc43b..36bfdc0 100644 --- a/scikit_mol/fingerprints.py +++ b/scikit_mol/fingerprints.py @@ -16,9 +16,11 @@ from scipy.sparse import vstack from sklearn.base import BaseEstimator, TransformerMixin +from scikit_mol._invalid import ArrayWithInvalidInstances, rdkit_error_handling from abc import ABC, abstractmethod + #%% class FpsTransformer(ABC, BaseEstimator, TransformerMixin): @@ -49,10 +51,10 @@ def fit(self, X, y=None): return self def _transform(self, X): - arr = np.zeros((len(X), self.nBits), dtype=np.int8) + arr_list = [] for i, mol in enumerate(X): - arr[i,:] = self._transform_mol(mol) - return arr + arr_list.append(self._transform_mol(mol)) + return ArrayWithInvalidInstances(arr_list) def _transform_sparse(self, X): arr = np.zeros((len(X), self.nBits), dtype=np.int8) @@ -61,6 +63,7 @@ def _transform_sparse(self, X): return lil_matrix(arr) + @rdkit_error_handling def transform(self, X, y=None): """Transform a list of RDKit molecule objects into a fingerprint array @@ -89,9 +92,9 @@ def transform(self, X, y=None): #arrays = pool.map(self._transform, x_chunks) parameters = self.get_params() arrays = pool.map(parallel_helper, [(self.__class__.__name__, parameters, x_chunk) for x_chunk in x_chunks]) - - arr = np.concatenate(arrays) - return arr + arr_list = [] + arr_list.extend(arrays) + return arr_list class MACCSKeysFingerprintTransformer(FpsTransformer): diff --git a/scikit_mol/utilities.py b/scikit_mol/utilities.py index 24356d4..866a9aa 100644 --- a/scikit_mol/utilities.py +++ b/scikit_mol/utilities.py @@ -1,14 +1,9 @@ #For a non-scikit-learn check smiles sanitizer class -from typing import NamedTuple + import pandas as pd from rdkit import Chem -class InvalidInstance(NamedTuple): - pipeline_step: str - error: str - - class CheckSmilesSanitazion: def __init__(self, return_mol=False): self.return_mol = return_mol diff --git a/scikit_mol/wrapper.py b/scikit_mol/wrapper.py index e69de29..56efe9d 100644 --- a/scikit_mol/wrapper.py +++ b/scikit_mol/wrapper.py @@ -0,0 +1,7 @@ +from abc import ABC +from sklearn.base import BaseEstimator + + +class AbstractWrapper(BaseEstimator, ABC): + pass + diff --git a/tests/fixtures.py b/tests/fixtures.py index 4102825..f6aba9e 100644 --- a/tests/fixtures.py +++ b/tests/fixtures.py @@ -8,7 +8,7 @@ @pytest.fixture def smiles_list(): - return [Chem.MolToSmiles(Chem.MolFromSmiles(smiles)) for smiles in ['O=C(O)c1ccccc1', + return [Chem.MolToSmiles(Chem.MolFromSmiles(smiles)) for smiles in ['O=C(O)c1ccccc1', 'O=C([O-])c1ccccc1', 'O=C([O-])c1ccccc1.[Na+]', 'O=C(O[Na])c1ccccc1', diff --git a/tests/test_invalid_handling.py b/tests/test_invalid_handling.py new file mode 100644 index 0000000..fb95060 --- /dev/null +++ b/tests/test_invalid_handling.py @@ -0,0 +1,25 @@ +import pytest +from scikit_mol.conversions import SmilesToMolTransformer +from scikit_mol.fingerprints import MorganFingerprintTransformer +from sklearn.pipeline import Pipeline +from fixtures import smiles_list, invalid_smiles_list +from scikit_mol._invalid import ArrayWithInvalidInstances + +@pytest.fixture +def smilestofp_pipeline(): + pipeline = Pipeline( + [ + ("smiles_to_mol", SmilesToMolTransformer()), + ("mol_2_fp", MorganFingerprintTransformer()), + ] + + ) + return pipeline + + +def test_descriptor_transformer(invalid_smiles_list, smilestofp_pipeline): + smilestofp_pipeline.set_params() + mol_list: ArrayWithInvalidInstances = smilestofp_pipeline.transform(invalid_smiles_list) + print(mol_list.is_valid_array) + print(mol_list.matrix) + print(mol_list.invalid_list) From e5c6a204d5dabd081910cc103ff543f0c86fe2f5 Mon Sep 17 00:00:00 2001 From: "Christian W. Feldmann" Date: Fri, 22 Sep 2023 17:43:29 +0200 Subject: [PATCH 03/41] first working draft --- scikit_mol/_invalid.py | 61 ++++++++++++++++++++++------------ scikit_mol/fingerprints.py | 4 +-- scikit_mol/wrapper.py | 57 ++++++++++++++++++++++++++++++- tests/test_invalid_handling.py | 16 +++++---- 4 files changed, 107 insertions(+), 31 deletions(-) diff --git a/scikit_mol/_invalid.py b/scikit_mol/_invalid.py index 900dc71..d87a460 100644 --- a/scikit_mol/_invalid.py +++ b/scikit_mol/_invalid.py @@ -13,45 +13,52 @@ class InvalidInstance(NamedTuple): error: str -class ArrayWithInvalidInstances: +class NumpyArrayWithInvalidInstances: + is_valid_array: npt.NDArray[np.bool_] invalid_list: list[InvalidInstance] + value_array: npt.NDArray[Any] - def __init__(self, array_list: list[npt.NDArray[np.int8] | InvalidInstance]): + def __init__(self, array_list: list[npt.NDArray[Any] | InvalidInstance]): self.is_valid_array = get_is_valid_array(array_list) valid_vector_list = filter_by_list(array_list, self.is_valid_array) - self.matrix = np.vstack(valid_vector_list) + self.value_array = np.vstack(valid_vector_list) self.invalid_list = filter_by_list(array_list, ~self.is_valid_array) - def __getitem__(self, item: int) -> npt.NDArray[np.int8] | InvalidInstance: + def __getitem__(self, item: int) -> npt.NDArray[Any] | InvalidInstance: n_invalids_prior = sum(~self.is_valid_array[:item - 1]) if self.is_valid_array[item]: - return self.matrix[item - n_invalids_prior] + return self.value_array[item - n_invalids_prior] return self.invalid_list[n_invalids_prior + 1] - def __setitem__(self, key: int, value: npt.NDArray[np.int8] | InvalidInstance) -> None: + def __setitem__(self, key: int, value: npt.NDArray[Any] | InvalidInstance) -> None: n_invalids_prior = sum(~self.is_valid_array[:key - 1]) if isinstance(value, InvalidInstance): if self.is_valid_array[key]: - self.matrix = np.delete(self.matrix, key - n_invalids_prior) + self.value_array = np.delete(self.value_array, key - n_invalids_prior) self.is_valid_array[key] = False self.invalid_list.insert(n_invalids_prior + 1, value) else: self.invalid_list[n_invalids_prior + 1] = value else: if self.is_valid_array[key]: - self.matrix[key - n_invalids_prior] = value + self.value_array[key - n_invalids_prior] = value else: - self.matrix = np.insert(self.matrix, key-n_invalids_prior, value) + self.value_array = np.insert(self.value_array, key - n_invalids_prior, value) del(self.invalid_list[n_invalids_prior + 1]) self.is_valid_array[key] = True + def array_filled_with(self, fill_value) -> npt.NDArray[Any]: + out = np.full((len(self.is_valid_array), self.value_array.shape[1]), fill_value) + out[self.is_valid_array] = self.value_array + return out -def update_list_by( - old_list: list[npt.NDArray[np.int8] | InvalidInstance] | ArrayWithInvalidInstances, + +def batch_update_sequence( + old_list: list[npt.NDArray[Any] | InvalidInstance] | NumpyArrayWithInvalidInstances, new_values: list[Any], value_indices: npt.NDArray[np.int_], ): - old_list = list(old_list) + old_list = list(old_list) # Make shallow copy of list to avoid inplace changes. for new_value, idx in zip(new_values, value_indices, strict=True): old_list[idx] = new_value return old_list @@ -67,32 +74,44 @@ def filter_by_list(item_list, is_valid_array: npt.NDArray[np.bool_]): item_list_new.append(item) return item_list_new + # Callable[[Sequence[Any], Sequence[Any], dict[str, Any]], Sequence[Any]] # ) -> Callable[[Sequence[Any], Sequence[Any], dict[str, Any]], npt.NDArray[Any]] def rdkit_error_handling(func): def wrapper(obj, *args, **kwargs): x = args[0] - if isinstance(x, ArrayWithInvalidInstances): + if isinstance(x, NumpyArrayWithInvalidInstances): is_valid_array = x.is_valid_array - x_sub = x.matrix + x_sub = x.value_array else: is_valid_array = get_is_valid_array(x) x_sub = filter_by_list(x, is_valid_array) if len(args) > 1: y = args[1] - y_sub = filter_by_list(y, is_valid_array) + if y is not None: + y_sub = filter_by_list(y, is_valid_array) + else: + y_sub = None + x_new = func(obj, x_sub, y_sub, **kwargs) else: - y_sub = None - x_new = func(obj, x_sub, y_sub, **kwargs) + x_new = func(obj, x_sub, **kwargs) + + if x_new is None: # fit may not return anything + return None new_pos = np.where(is_valid_array)[0] - if isinstance(x, (list, ArrayWithInvalidInstances)): - x_list = update_list_by(x, x_new, new_pos) + if isinstance(x_new, np.ndarray) and isinstance(x, NumpyArrayWithInvalidInstances): + if x_new.shape[0] != x.value_array.shape[0]: + raise AssertionError("Numer of rows is not as expected.") + x.value_array = x_new + return x + if isinstance(x, (list, NumpyArrayWithInvalidInstances)): + x_list = batch_update_sequence(x, x_new, new_pos) else: x_array = np.array(x) x_array[is_valid_array] = x_new x_list = list(x_array) - if isinstance(x_new, ArrayWithInvalidInstances): - return ArrayWithInvalidInstances(x_list) + if isinstance(x_new, NumpyArrayWithInvalidInstances): + return NumpyArrayWithInvalidInstances(x_list) return x_list return wrapper diff --git a/scikit_mol/fingerprints.py b/scikit_mol/fingerprints.py index 36bfdc0..2998136 100644 --- a/scikit_mol/fingerprints.py +++ b/scikit_mol/fingerprints.py @@ -16,7 +16,7 @@ from scipy.sparse import vstack from sklearn.base import BaseEstimator, TransformerMixin -from scikit_mol._invalid import ArrayWithInvalidInstances, rdkit_error_handling +from scikit_mol._invalid import NumpyArrayWithInvalidInstances, rdkit_error_handling from abc import ABC, abstractmethod @@ -54,7 +54,7 @@ def _transform(self, X): arr_list = [] for i, mol in enumerate(X): arr_list.append(self._transform_mol(mol)) - return ArrayWithInvalidInstances(arr_list) + return NumpyArrayWithInvalidInstances(arr_list) def _transform_sparse(self, X): arr = np.zeros((len(X), self.nBits), dtype=np.int8) diff --git a/scikit_mol/wrapper.py b/scikit_mol/wrapper.py index 56efe9d..17787af 100644 --- a/scikit_mol/wrapper.py +++ b/scikit_mol/wrapper.py @@ -1,7 +1,62 @@ from abc import ABC +from typing import Any + +import numpy as np from sklearn.base import BaseEstimator +from sklearn.pipeline import Pipeline +from sklearn.utils.metaestimators import available_if + +from scikit_mol._invalid import rdkit_error_handling, InvalidInstance, NumpyArrayWithInvalidInstances class AbstractWrapper(BaseEstimator, ABC): - pass + model: BaseEstimator | Pipeline + + def __init__(self, replace_invalid: bool, replace_value=np.nan): + self.replace_invalid = replace_invalid + self.replace_value = replace_value + + @rdkit_error_handling + def fit(self, X, y, **fit_params) -> Any: + return self.model.fit(X, y, **fit_params) + + def has_predict(self) -> bool: + return hasattr(self.model, "predict") + + def has_fit_predict(self) -> bool: + return hasattr(self.model, "fit_predict") + + +class WrappedTransformer(AbstractWrapper): + def __init__(self, model: BaseEstimator, replace_invalid: bool = False, replace_value=np.nan): + super().__init__(replace_invalid=replace_invalid, replace_value=replace_value) + self.model = model + + def has_transform(self) -> bool: + return hasattr(self.model, "transform") + + def has_fit_transform(self) -> bool: + return hasattr(self.model, "fit_transform") + + @available_if(has_transform) + @rdkit_error_handling + def transform(self, X): + return self.model.transform(X) + + @rdkit_error_handling + def _fit_transform(self, X, y): + return self.model.fit_transform(X, y) + @available_if(has_fit_transform) + def fit_transform(self, X, y=None): + out = self._fit_transform(X,y) + if not self.replace_invalid: + return out + + if isinstance(out, NumpyArrayWithInvalidInstances): + return out.array_filled_with(self.replace_value) + + if isinstance(out, list): + return [self.replace_value if isinstance(v, InvalidInstance) else v for v in out] + + diff --git a/tests/test_invalid_handling.py b/tests/test_invalid_handling.py index fb95060..42e0344 100644 --- a/tests/test_invalid_handling.py +++ b/tests/test_invalid_handling.py @@ -1,9 +1,12 @@ import pytest -from scikit_mol.conversions import SmilesToMolTransformer -from scikit_mol.fingerprints import MorganFingerprintTransformer +from sklearn.decomposition import PCA from sklearn.pipeline import Pipeline + from fixtures import smiles_list, invalid_smiles_list -from scikit_mol._invalid import ArrayWithInvalidInstances +from scikit_mol.conversions import SmilesToMolTransformer +from scikit_mol.fingerprints import MorganFingerprintTransformer +from scikit_mol.wrapper import WrappedTransformer +from scikit_mol._invalid import NumpyArrayWithInvalidInstances @pytest.fixture def smilestofp_pipeline(): @@ -11,6 +14,7 @@ def smilestofp_pipeline(): [ ("smiles_to_mol", SmilesToMolTransformer()), ("mol_2_fp", MorganFingerprintTransformer()), + ("PCA", WrappedTransformer(PCA(2), replace_invalid=True)) ] ) @@ -19,7 +23,5 @@ def smilestofp_pipeline(): def test_descriptor_transformer(invalid_smiles_list, smilestofp_pipeline): smilestofp_pipeline.set_params() - mol_list: ArrayWithInvalidInstances = smilestofp_pipeline.transform(invalid_smiles_list) - print(mol_list.is_valid_array) - print(mol_list.matrix) - print(mol_list.invalid_list) + mol_list: NumpyArrayWithInvalidInstances = smilestofp_pipeline.fit_transform(invalid_smiles_list) + print(mol_list) From 2593ca212a647c280da65fb1c67a38daa4e3b5b8 Mon Sep 17 00:00:00 2001 From: "Christian W. Feldmann" Date: Fri, 13 Sep 2024 10:41:44 +0200 Subject: [PATCH 04/41] Add docstrings --- scikit_mol/wrapper.py | 35 ++++++++++++++++++++++++++++++++++- 1 file changed, 34 insertions(+), 1 deletion(-) diff --git a/scikit_mol/wrapper.py b/scikit_mol/wrapper.py index 17787af..1899e27 100644 --- a/scikit_mol/wrapper.py +++ b/scikit_mol/wrapper.py @@ -1,3 +1,5 @@ +"""Wrapper for sklearn estimators and pipelines to handle errors.""" + from abc import ABC from typing import Any @@ -10,9 +12,26 @@ class AbstractWrapper(BaseEstimator, ABC): + """ + Abstract class for the wrapper of sklearn objects. + + Attributes + ---------- + model: BaseEstimator | Pipeline + The wrapped model or pipeline. + """ model: BaseEstimator | Pipeline - def __init__(self, replace_invalid: bool, replace_value=np.nan): + def __init__(self, replace_invalid: bool, replace_value: Any = np.nan): + """Initialize the AbstractWrapper. + + Parameters + ---------- + replace_invalid: bool + Whether to replace or remove errors + replace_value: Any, default=np.nan + If replace_invalid==True, insert this value on the erroneous instance. + """ self.replace_invalid = replace_invalid self.replace_value = replace_value @@ -28,7 +47,20 @@ def has_fit_predict(self) -> bool: class WrappedTransformer(AbstractWrapper): + """Wrapper for sklearn transformer objects.""" + def __init__(self, model: BaseEstimator, replace_invalid: bool = False, replace_value=np.nan): + """Initialize the WrappedTransformer. + + Parameters + ---------- + model: BaseEstimator + Wrapped model to be protected against Errors. + replace_invalid: bool + Whether to replace or remove errors + replace_value: Any, default=np.nan + If replace_invalid==True, insert this value on the erroneous instance. + """ super().__init__(replace_invalid=replace_invalid, replace_value=replace_value) self.model = model @@ -46,6 +78,7 @@ def transform(self, X): @rdkit_error_handling def _fit_transform(self, X, y): return self.model.fit_transform(X, y) + @available_if(has_fit_transform) def fit_transform(self, X, y=None): out = self._fit_transform(X,y) From de5d0d892731323f4ea65eae8d2aa3675f37b291 Mon Sep 17 00:00:00 2001 From: "Christian W. Feldmann" Date: Fri, 13 Sep 2024 11:03:18 +0200 Subject: [PATCH 05/41] refactor unittest --- tests/test_invalid_handling.py | 23 ++++++++++++++++++++--- 1 file changed, 20 insertions(+), 3 deletions(-) diff --git a/tests/test_invalid_handling.py b/tests/test_invalid_handling.py index 42e0344..9ebaafc 100644 --- a/tests/test_invalid_handling.py +++ b/tests/test_invalid_handling.py @@ -1,3 +1,4 @@ +import numpy as np import pytest from sklearn.decomposition import PCA from sklearn.pipeline import Pipeline @@ -7,12 +8,15 @@ from scikit_mol.fingerprints import MorganFingerprintTransformer from scikit_mol.wrapper import WrappedTransformer from scikit_mol._invalid import NumpyArrayWithInvalidInstances +from tests.test_invalid_helpers.invalid_transformer import TestInvalidTransformer + @pytest.fixture def smilestofp_pipeline(): pipeline = Pipeline( [ ("smiles_to_mol", SmilesToMolTransformer()), + ("remove_sulfur", TestInvalidTransformer()), ("mol_2_fp", MorganFingerprintTransformer()), ("PCA", WrappedTransformer(PCA(2), replace_invalid=True)) ] @@ -21,7 +25,20 @@ def smilestofp_pipeline(): return pipeline -def test_descriptor_transformer(invalid_smiles_list, smilestofp_pipeline): +def test_descriptor_transformer(smiles_list, invalid_smiles_list, smilestofp_pipeline): smilestofp_pipeline.set_params() - mol_list: NumpyArrayWithInvalidInstances = smilestofp_pipeline.fit_transform(invalid_smiles_list) - print(mol_list) + mol_pca = smilestofp_pipeline.fit_transform(smiles_list) + error_mol_pca = smilestofp_pipeline.fit_transform(invalid_smiles_list) + + if mol_pca.shape != (len(smiles_list), 2): + raise ValueError("The PCA does not return the proper dimensions.") + if isinstance(error_mol_pca, NumpyArrayWithInvalidInstances): + raise TypeError("The Errors were not properly remove from the output array.") + + expected_nans = np.array([[0, 0, 1, 1], [0, 1, 0, 1]]) + if not np.all(np.equal(expected_nans, np.where(np.isnan(error_mol_pca)))): + raise ValueError("Errors were replaced on the wrong positions.") + + non_nan_rows = ~np.any(np.isnan(error_mol_pca), axis=1) + if not np.all(np.isclose(mol_pca, error_mol_pca[non_nan_rows, :])): + raise ValueError("Removing errors introduces changes in the PCA output.") From 7ddbfe122153c639b9aed247552552a33d8bb084 Mon Sep 17 00:00:00 2001 From: "Christian W. Feldmann" Date: Fri, 13 Sep 2024 11:03:44 +0200 Subject: [PATCH 06/41] add Transformer which can make molecules invalid --- .../invalid_transformer.py | 44 +++++++++++++++++++ 1 file changed, 44 insertions(+) create mode 100644 tests/test_invalid_helpers/invalid_transformer.py diff --git a/tests/test_invalid_helpers/invalid_transformer.py b/tests/test_invalid_helpers/invalid_transformer.py new file mode 100644 index 0000000..ffe2738 --- /dev/null +++ b/tests/test_invalid_helpers/invalid_transformer.py @@ -0,0 +1,44 @@ +from typing import Optional, Sequence +from sklearn.base import BaseEstimator, TransformerMixin +from rdkit import Chem + +from scikit_mol._invalid import ( + InvalidInstance, + rdkit_error_handling, +) + + +class TestInvalidTransformer(BaseEstimator, TransformerMixin): + """This class is ment for tesing purposes only. + + All molecules with element number are returned as invalid instance. + + Attributes + --------- + atomic_number_set: set[int] + Atomic numbers which upon occurrence in the molecule make it invalid. + """ + + atomic_number_set: set[int] + + def __init__(self, atomic_number_set: Sequence[int] | None = None) -> None: + if atomic_number_set is None: + atomic_number_set = {16} + self.atomic_number_set = set(atomic_number_set) + + def _transform_mol(self, mol: Chem.Mol) -> Chem.Mol | InvalidInstance: + unique_elements = {atom.GetAtomicNum() for atom in mol.GetAtoms()} + forbidden_elements = self.atomic_number_set & unique_elements + if forbidden_elements: + return InvalidInstance(str(self), f"Molecule contains {forbidden_elements}") + return mol + + @rdkit_error_handling + def transform(self, X: list[Chem.Mol]) -> list[Chem.Mol | InvalidInstance]: + return [self._transform_mol(mol) for mol in X] + + def fit(self, X, y, fit_params): + pass + + def fit_transform(self, X, y=None, **fit_params): + return self.transform(X) From 17f19d4428d028bf078ec78041214cb9750faa52 Mon Sep 17 00:00:00 2001 From: "Christian W. Feldmann" Date: Fri, 13 Sep 2024 11:04:10 +0200 Subject: [PATCH 07/41] Resolve merge issues with smiles_list --- tests/fixtures.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/fixtures.py b/tests/fixtures.py index 8cf0751..2852068 100644 --- a/tests/fixtures.py +++ b/tests/fixtures.py @@ -63,8 +63,9 @@ def chiral_smiles_list(): #Need to be a certain size, so the fingerprints reacts 'N[C@@H](C)C(=O)Oc1ccccc1CCCCCCCCCCCCCCCCCCN[H]']] @pytest.fixture -def invalid_smiles_list(smiles_list): - smiles_list.append('Invalid') +def invalid_smiles_list(): + smiles_list = ['S-CC', 'Invalid'] + smiles_list.extend(_SMILES_LIST) return smiles_list _MOLS_LIST = [Chem.MolFromSmiles(smiles) for smiles in _SMILES_LIST] From 6c1d3e68e16e77ac19414e28a5fa815b14d3c72f Mon Sep 17 00:00:00 2001 From: "Christian W. Feldmann" Date: Fri, 13 Sep 2024 11:04:21 +0200 Subject: [PATCH 08/41] Add docstring --- scikit_mol/_invalid.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/scikit_mol/_invalid.py b/scikit_mol/_invalid.py index d87a460..f850adf 100644 --- a/scikit_mol/_invalid.py +++ b/scikit_mol/_invalid.py @@ -9,11 +9,17 @@ class InvalidInstance(NamedTuple): + """ + The InvalidInstance represents objects which raised an error during a pipeline step. + """ pipeline_step: str error: str class NumpyArrayWithInvalidInstances: + """ + The NumpyArrayWithInvalidInstances is + """ is_valid_array: npt.NDArray[np.bool_] invalid_list: list[InvalidInstance] value_array: npt.NDArray[Any] From 591a2ac7afde7c4670a05ec901c1e616d9f2164e Mon Sep 17 00:00:00 2001 From: "Christian W. Feldmann" Date: Fri, 13 Sep 2024 11:05:24 +0200 Subject: [PATCH 09/41] Add init --- tests/test_invalid_helpers/__init__.py | 1 + 1 file changed, 1 insertion(+) create mode 100644 tests/test_invalid_helpers/__init__.py diff --git a/tests/test_invalid_helpers/__init__.py b/tests/test_invalid_helpers/__init__.py new file mode 100644 index 0000000..d0c583a --- /dev/null +++ b/tests/test_invalid_helpers/__init__.py @@ -0,0 +1 @@ +"""Initialize module for helper classes and functions used to test the handling of invalid inputs.""" From c713ece6809d4736b7f49980672eae6349447517 Mon Sep 17 00:00:00 2001 From: "Christian W. Feldmann" Date: Fri, 13 Sep 2024 11:10:53 +0200 Subject: [PATCH 10/41] Add Message encountering Errors --- scikit_mol/conversions.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/scikit_mol/conversions.py b/scikit_mol/conversions.py index db89661..2399eba 100644 --- a/scikit_mol/conversions.py +++ b/scikit_mol/conversions.py @@ -64,7 +64,8 @@ def _transform(self, X): if mol: X_out.append(mol) else: - X_out.append(InvalidInstance(str(self), "Invalid Smiles.")) + error = Chem.DetectChemistryProblems(mol) + X_out.append(InvalidInstance(str(self), f"Invalid Smiles.: {error}")) return X_out @check_transform_input From 85c9745b83d2e4335d754a2be2cd0b108924a59e Mon Sep 17 00:00:00 2001 From: "Christian W. Feldmann" Date: Fri, 13 Sep 2024 11:54:33 +0200 Subject: [PATCH 11/41] Fix Message encountering Errors --- .idea/.gitignore | 3 +++ .idea/inspectionProfiles/Project_Default.xml | 6 ++++++ .idea/inspectionProfiles/profiles_settings.xml | 6 ++++++ .idea/misc.xml | 10 ++++++++++ .idea/scikit-mol.iml | 13 +++++++++++++ .idea/vcs.xml | 6 ++++++ scikit_mol/conversions.py | 10 ++++++++-- 7 files changed, 52 insertions(+), 2 deletions(-) create mode 100644 .idea/.gitignore create mode 100644 .idea/inspectionProfiles/Project_Default.xml create mode 100644 .idea/inspectionProfiles/profiles_settings.xml create mode 100644 .idea/misc.xml create mode 100644 .idea/scikit-mol.iml create mode 100644 .idea/vcs.xml diff --git a/.idea/.gitignore b/.idea/.gitignore new file mode 100644 index 0000000..26d3352 --- /dev/null +++ b/.idea/.gitignore @@ -0,0 +1,3 @@ +# Default ignored files +/shelf/ +/workspace.xml diff --git a/.idea/inspectionProfiles/Project_Default.xml b/.idea/inspectionProfiles/Project_Default.xml new file mode 100644 index 0000000..9aa2337 --- /dev/null +++ b/.idea/inspectionProfiles/Project_Default.xml @@ -0,0 +1,6 @@ + + + + \ No newline at end of file diff --git a/.idea/inspectionProfiles/profiles_settings.xml b/.idea/inspectionProfiles/profiles_settings.xml new file mode 100644 index 0000000..105ce2d --- /dev/null +++ b/.idea/inspectionProfiles/profiles_settings.xml @@ -0,0 +1,6 @@ + + + + \ No newline at end of file diff --git a/.idea/misc.xml b/.idea/misc.xml new file mode 100644 index 0000000..13c8595 --- /dev/null +++ b/.idea/misc.xml @@ -0,0 +1,10 @@ + + + + + + + + \ No newline at end of file diff --git a/.idea/scikit-mol.iml b/.idea/scikit-mol.iml new file mode 100644 index 0000000..fb6d745 --- /dev/null +++ b/.idea/scikit-mol.iml @@ -0,0 +1,13 @@ + + + + + + + + + \ No newline at end of file diff --git a/.idea/vcs.xml b/.idea/vcs.xml new file mode 100644 index 0000000..35eb1dd --- /dev/null +++ b/.idea/vcs.xml @@ -0,0 +1,6 @@ + + + + + + \ No newline at end of file diff --git a/scikit_mol/conversions.py b/scikit_mol/conversions.py index 2399eba..b6537e7 100644 --- a/scikit_mol/conversions.py +++ b/scikit_mol/conversions.py @@ -64,8 +64,14 @@ def _transform(self, X): if mol: X_out.append(mol) else: - error = Chem.DetectChemistryProblems(mol) - X_out.append(InvalidInstance(str(self), f"Invalid Smiles.: {error}")) + mol = Chem.MolFromSmiles(smiles, sanitize=False) + if mol: + errors = Chem.DetectChemistryProblems(mol) + error_message = "\n".join(error.Message() for error in errors) + message = f"Invalid SMILES: {error_message}" + else: + message = f"Invalid SMILES: {smiles}" + X_out.append(InvalidInstance(str(self), message)) return X_out @check_transform_input From ce55a52f131d87d49ce884a5cb0e37b82bd2442e Mon Sep 17 00:00:00 2001 From: "Christian W. Feldmann" Date: Fri, 13 Sep 2024 11:54:50 +0200 Subject: [PATCH 12/41] Fix reference to test classes --- tests/test_invalid_handling.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/test_invalid_handling.py b/tests/test_invalid_handling.py index 9ebaafc..b848659 100644 --- a/tests/test_invalid_handling.py +++ b/tests/test_invalid_handling.py @@ -8,7 +8,7 @@ from scikit_mol.fingerprints import MorganFingerprintTransformer from scikit_mol.wrapper import WrappedTransformer from scikit_mol._invalid import NumpyArrayWithInvalidInstances -from tests.test_invalid_helpers.invalid_transformer import TestInvalidTransformer +from test_invalid_helpers.invalid_transformer import TestInvalidTransformer @pytest.fixture @@ -26,6 +26,7 @@ def smilestofp_pipeline(): def test_descriptor_transformer(smiles_list, invalid_smiles_list, smilestofp_pipeline): + smilestofp_pipeline.set_params() mol_pca = smilestofp_pipeline.fit_transform(smiles_list) error_mol_pca = smilestofp_pipeline.fit_transform(invalid_smiles_list) From 9e72d676134049b6de53d308f0880586de3fa6d9 Mon Sep 17 00:00:00 2001 From: "Christian W. Feldmann" Date: Fri, 13 Sep 2024 11:55:14 +0200 Subject: [PATCH 13/41] Add __len__ to class --- scikit_mol/_invalid.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/scikit_mol/_invalid.py b/scikit_mol/_invalid.py index f850adf..e1a91e5 100644 --- a/scikit_mol/_invalid.py +++ b/scikit_mol/_invalid.py @@ -30,6 +30,9 @@ def __init__(self, array_list: list[npt.NDArray[Any] | InvalidInstance]): self.value_array = np.vstack(valid_vector_list) self.invalid_list = filter_by_list(array_list, ~self.is_valid_array) + def __len__(self): + return self.is_valid_array.shape[0] + def __getitem__(self, item: int) -> npt.NDArray[Any] | InvalidInstance: n_invalids_prior = sum(~self.is_valid_array[:item - 1]) if self.is_valid_array[item]: From 0c6cf7ee4788947a1a1241e839cd7b657bf9f7a8 Mon Sep 17 00:00:00 2001 From: Esben Jannik Bjerrum Date: Fri, 27 Sep 2024 09:37:22 +0200 Subject: [PATCH 14/41] Simplifying datatypes. Conversions use invalidMol and MACCSkeys are robust to falsy input. --- scikit_mol/_invalid.py | 31 +- scikit_mol/conversions.py | 51 ++- scikit_mol/core.py | 60 ++- scikit_mol/fingerprints.py | 385 ++++++++++++------ scikit_mol/wrapper.py | 18 +- .../invalid_transformer.py | 8 +- 6 files changed, 368 insertions(+), 185 deletions(-) diff --git a/scikit_mol/_invalid.py b/scikit_mol/_invalid.py index e1a91e5..7130216 100644 --- a/scikit_mol/_invalid.py +++ b/scikit_mol/_invalid.py @@ -12,6 +12,7 @@ class InvalidInstance(NamedTuple): """ The InvalidInstance represents objects which raised an error during a pipeline step. """ + pipeline_step: str error: str @@ -20,6 +21,7 @@ class NumpyArrayWithInvalidInstances: """ The NumpyArrayWithInvalidInstances is """ + is_valid_array: npt.NDArray[np.bool_] invalid_list: list[InvalidInstance] value_array: npt.NDArray[Any] @@ -34,13 +36,13 @@ def __len__(self): return self.is_valid_array.shape[0] def __getitem__(self, item: int) -> npt.NDArray[Any] | InvalidInstance: - n_invalids_prior = sum(~self.is_valid_array[:item - 1]) + n_invalids_prior = sum(~self.is_valid_array[: item - 1]) if self.is_valid_array[item]: return self.value_array[item - n_invalids_prior] return self.invalid_list[n_invalids_prior + 1] def __setitem__(self, key: int, value: npt.NDArray[Any] | InvalidInstance) -> None: - n_invalids_prior = sum(~self.is_valid_array[:key - 1]) + n_invalids_prior = sum(~self.is_valid_array[: key - 1]) if isinstance(value, InvalidInstance): if self.is_valid_array[key]: self.value_array = np.delete(self.value_array, key - n_invalids_prior) @@ -52,8 +54,10 @@ def __setitem__(self, key: int, value: npt.NDArray[Any] | InvalidInstance) -> No if self.is_valid_array[key]: self.value_array[key - n_invalids_prior] = value else: - self.value_array = np.insert(self.value_array, key - n_invalids_prior, value) - del(self.invalid_list[n_invalids_prior + 1]) + self.value_array = np.insert( + self.value_array, key - n_invalids_prior, value + ) + del self.invalid_list[n_invalids_prior + 1] self.is_valid_array[key] = True def array_filled_with(self, fill_value) -> npt.NDArray[Any]: @@ -63,10 +67,10 @@ def array_filled_with(self, fill_value) -> npt.NDArray[Any]: def batch_update_sequence( - old_list: list[npt.NDArray[Any] | InvalidInstance] | NumpyArrayWithInvalidInstances, - new_values: list[Any], - value_indices: npt.NDArray[np.int_], - ): + old_list: list[npt.NDArray[Any] | InvalidInstance] | NumpyArrayWithInvalidInstances, + new_values: list[Any], + value_indices: npt.NDArray[np.int_], +): old_list = list(old_list) # Make shallow copy of list to avoid inplace changes. for new_value, idx in zip(new_values, value_indices, strict=True): old_list[idx] = new_value @@ -108,7 +112,9 @@ def wrapper(obj, *args, **kwargs): if x_new is None: # fit may not return anything return None new_pos = np.where(is_valid_array)[0] - if isinstance(x_new, np.ndarray) and isinstance(x, NumpyArrayWithInvalidInstances): + if isinstance(x_new, np.ndarray) and isinstance( + x, NumpyArrayWithInvalidInstances + ): if x_new.shape[0] != x.value_array.shape[0]: raise AssertionError("Numer of rows is not as expected.") x.value_array = x_new @@ -122,12 +128,11 @@ def wrapper(obj, *args, **kwargs): if isinstance(x_new, NumpyArrayWithInvalidInstances): return NumpyArrayWithInvalidInstances(x_list) return x_list + return wrapper -def filter_rows( - X: Sequence[_T], y: Sequence[_U] -) -> tuple[Sequence[_T], Sequence[_U]]: +def filter_rows(X: Sequence[_T], y: Sequence[_U]) -> tuple[Sequence[_T], Sequence[_U]]: is_valid_array = get_is_valid_array(X) x_new = filter_by_list(X, is_valid_array) y_new = filter_by_list(y, is_valid_array) @@ -142,5 +147,3 @@ def get_is_valid_array(item_list: Sequence[Any]) -> npt.NDArray[np.bool_]: else: is_valid_list.append(False) return np.array(is_valid_list, dtype=bool) - - diff --git a/scikit_mol/conversions.py b/scikit_mol/conversions.py index b6537e7..00afddb 100644 --- a/scikit_mol/conversions.py +++ b/scikit_mol/conversions.py @@ -6,16 +6,20 @@ import numpy as np from sklearn.base import BaseEstimator, TransformerMixin -from scikit_mol.core import check_transform_input, feature_names_default_mol ,DEFAULT_MOL_COLUMN_NAME +from scikit_mol.core import ( + check_transform_input, + feature_names_default_mol, + DEFAULT_MOL_COLUMN_NAME, + InvalidMol, +) -from scikit_mol._invalid import InvalidInstance +# from scikit_mol._invalid import InvalidMol class SmilesToMolTransformer(BaseEstimator, TransformerMixin): - def __init__(self, parallel: Union[bool, int] = False): self.parallel = parallel - self.start_method = None #TODO implement handling of start_method + self.start_method = None # TODO implement handling of start_method @feature_names_default_mol def get_feature_names_out(self, input_features=None): @@ -43,18 +47,25 @@ def transform(self, X_smiles_list, y=None): ValueError Raises ValueError if a SMILES string is unparsable by RDKit """ - if not self.parallel: return self._transform(X_smiles_list) elif self.parallel: - n_processes = self.parallel if self.parallel > 1 else None # Pool(processes=None) autodetects - n_chunks = n_processes*2 if n_processes is not None else multiprocessing.cpu_count()*2 #TODO, tune the number of chunks per child process + n_processes = ( + self.parallel if self.parallel > 1 else None + ) # Pool(processes=None) autodetects + n_chunks = ( + n_processes * 2 + if n_processes is not None + else multiprocessing.cpu_count() * 2 + ) # TODO, tune the number of chunks per child process with get_context(self.start_method).Pool(processes=n_processes) as pool: - x_chunks = np.array_split(X_smiles_list, n_chunks) - arrays = pool.map(self._transform, x_chunks) #is the helper function a safer way of handling the picklind and child process communication - arr = np.concatenate(arrays) - return arr + x_chunks = np.array_split(X_smiles_list, n_chunks) + arrays = pool.map( + self._transform, x_chunks + ) # is the helper function a safer way of handling the picklind and child process communication + arr = np.concatenate(arrays) + return arr @check_transform_input def _transform(self, X): @@ -71,15 +82,23 @@ def _transform(self, X): message = f"Invalid SMILES: {error_message}" else: message = f"Invalid SMILES: {smiles}" - X_out.append(InvalidInstance(str(self), message)) + X_out.append(InvalidMol(str(self), message)) return X_out @check_transform_input - def inverse_transform(self, X_mols_list, y=None): #TODO, maybe the inverse transform should be configurable e.g. isomericSmiles etc.? + def inverse_transform( + self, X_mols_list, y=None + ): # TODO, maybe the inverse transform should be configurable e.g. isomericSmiles etc.? X_out = [] for mol in X_mols_list: - smiles = Chem.MolToSmiles(mol) - X_out.append(smiles) + if mol: + try: + smiles = Chem.MolToSmiles(mol) + X_out.append(smiles) + except Exception as e: + X_out.append(InvalidMol(str(self), str(e))) + else: + X_out.append(InvalidMol(str(self), f"Not a Mol: {mol}")) - return np.array(X_out).reshape(-1,1) + return np.array(X_out).reshape(-1, 1) diff --git a/scikit_mol/core.py b/scikit_mol/core.py index 66685a6..9b13680 100644 --- a/scikit_mol/core.py +++ b/scikit_mol/core.py @@ -5,6 +5,7 @@ Users who want to create their own transformers should use this module. """ +from dataclasses import dataclass import functools import numpy as np @@ -16,48 +17,71 @@ DEFAULT_MOL_COLUMN_NAME = "ROMol" +@dataclass +class InvalidMol: + """ + Represents molecules which raised an error during a pipeline step. + Evaluates to False in boolean contexts. + """ + + pipeline_step: str + error: str + + def __bool__(self): + return False + + def __repr__(self): + return f"InvalidMol('{self.pipeline_step}', error='{self.error}')" + + def _validate_transform_input(X): - """Validate and adapt the input of the _transform method""" - try: - shape = X.shape - except AttributeError: - # If X is not array-like or dataframe-like, - # we just return it as is, so users can use simple lists and sequences. - return X - # If X is an array-like or dataframe-like, we make sure it is compatible with - # the scikit-learn API, and that it contains a single column: - # scikit-mol transformers need a single column with smiles or mols. - if len(shape) == 1: - return X # Flatt Arrays and list-like data are also supported #TODO, add a warning about non-2D data if logging is implemented - if shape[1] != 1: - raise ValueError("Only one column supported. You may want to use a ColumnTransformer https://scikit-learn.org/stable/modules/generated/sklearn.compose.ColumnTransformer.html ") - return np.array(X).flatten() + """Validate and adapt the input of the _transform method""" + try: + shape = X.shape + except AttributeError: + # If X is not array-like or dataframe-like, + # we just return it as is, so users can use simple lists and sequences. + return X + # If X is an array-like or dataframe-like, we make sure it is compatible with + # the scikit-learn API, and that it contains a single column: + # scikit-mol transformers need a single column with smiles or mols. + if len(shape) == 1: + return X # Flatt Arrays and list-like data are also supported #TODO, add a warning about non-2D data if logging is implemented + if shape[1] != 1: + raise ValueError( + "Only one column supported. You may want to use a ColumnTransformer https://scikit-learn.org/stable/modules/generated/sklearn.compose.ColumnTransformer.html " + ) + return np.array(X).flatten() + def check_transform_input(method): """ Decorator to check the input of the _transform method and make it compatible with the scikit-learn API and with downstream methods. """ + @functools.wraps(method) def wrapper(obj, X): X = _validate_transform_input(X) - result = method(obj, X) + result = method(obj, X) # If the output of the _transform method # must be changed depending on the initial type of X, do it here. return result return wrapper + def feature_names_default_mol(method): """ Decorator that returns the default feature names for the mol object """ + @functools.wraps(method) def wrapper(obj, input_features=None): prefix = DEFAULT_MOL_COLUMN_NAME if input_features is not None: - return np.array([f'{prefix}_{name}' for name in input_features]) + return np.array([f"{prefix}_{name}" for name in input_features]) else: return np.array([prefix]) - return wrapper \ No newline at end of file + return wrapper diff --git a/scikit_mol/fingerprints.py b/scikit_mol/fingerprints.py index 65b08f8..80a3123 100644 --- a/scikit_mol/fingerprints.py +++ b/scikit_mol/fingerprints.py @@ -1,11 +1,12 @@ -#%% +# %% from multiprocessing import Pool, get_context import multiprocessing import re from typing import Union from rdkit import Chem from rdkit import DataStructs -#from rdkit.Chem.AllChem import GetMorganFingerprintAsBitVect + +# from rdkit.Chem.AllChem import GetMorganFingerprintAsBitVect from rdkit.Chem import rdMolDescriptors from rdkit.Chem import rdFingerprintGenerator from rdkit.Chem import rdMHFPFingerprint @@ -23,14 +24,16 @@ from abc import ABC, abstractmethod -_PATTERN_FINGERPRINT_TRANSFORMER = re.compile(r"^(?P\w+)FingerprintTransformer$") +_PATTERN_FINGERPRINT_TRANSFORMER = re.compile( + r"^(?P\w+)FingerprintTransformer$" +) -#%% -class FpsTransformer(ABC, BaseEstimator, TransformerMixin): +# %% +class FpsTransformer(ABC, BaseEstimator, TransformerMixin): def __init__(self, parallel: Union[bool, int] = False, start_method: str = None): self.parallel = parallel - self.start_method = start_method #TODO implement handling of start_method + self.start_method = start_method # TODO implement handling of start_method # The dtype of the fingerprint array computed by the transformer # If needed this property can be overwritten in the child class. @@ -56,7 +59,9 @@ def get_display_feature_names_out(self, input_features=None): """ prefix = self._get_column_prefix() n_digits = self._get_n_digits_column_suffix() - return np.array([f"{prefix}_{str(i).zfill(n_digits)}" for i in range(1, self.nBits + 1)]) + return np.array( + [f"{prefix}_{str(i).zfill(n_digits)}" for i in range(1, self.nBits + 1)] + ) def get_feature_names_out(self, input_features=None): """Get feature names for fingerprint transformers @@ -77,7 +82,10 @@ def _mol2fp(self, mol): def _fp2array(self, fp): arr = np.zeros((self.nBits,), dtype=self._DTYPE_FINGERPRINT) - DataStructs.ConvertToNumpyArray(fp, arr) + if fp: + DataStructs.ConvertToNumpyArray(fp, arr) + else: + arr[:] = np.nan # Sadly, dtype=int8 does not allow for NaN values return arr def _transform_mol(self, mol): @@ -94,19 +102,18 @@ def fit(self, X, y=None): @check_transform_input def _transform(self, X): - arr_list = [] + arr = np.zeros((len(X), self.nBits), dtype=self._DTYPE_FINGERPRINT) for i, mol in enumerate(X): - arr_list.append(self._transform_mol(mol)) - return NumpyArrayWithInvalidInstances(arr_list) + arr[i, :] = self._transform_mol(mol) + return arr def _transform_sparse(self, X): arr = np.zeros((len(X), self.nBits), dtype=self._DTYPE_FINGERPRINT) for i, mol in enumerate(X): - arr[i,:] = self._transform_mol(mol) - + arr[i, :] = self._transform_mol(mol) + return lil_matrix(arr) - @rdkit_error_handling def transform(self, X, y=None): """Transform a list of RDKit molecule objects into a fingerprint array @@ -126,29 +133,41 @@ def transform(self, X, y=None): return self._transform(X) elif self.parallel: - n_processes = self.parallel if self.parallel > 1 else None # Pool(processes=None) autodetects - n_chunks = n_processes if n_processes is not None else multiprocessing.cpu_count() - + n_processes = ( + self.parallel if self.parallel > 1 else None + ) # Pool(processes=None) autodetects + n_chunks = ( + n_processes if n_processes is not None else multiprocessing.cpu_count() + ) + with get_context(self.start_method).Pool(processes=n_processes) as pool: x_chunks = np.array_split(X, n_chunks) - #TODO check what is fastest, pickle or recreate and do this only for classes that need this - #arrays = pool.map(self._transform, x_chunks) + # TODO check what is fastest, pickle or recreate and do this only for classes that need this + # arrays = pool.map(self._transform, x_chunks) parameters = self.get_params() # TODO: create "transform_parallel" function in the core module, # and use it here and in the descriptors transformer - #x_chunks = [np.array(x).reshape(-1, 1) for x in x_chunks] - arrays = pool.map(parallel_helper, [(self.__class__.__name__, parameters, x_chunk) for x_chunk in x_chunks]) - arr_list = [] - arr_list.extend(arrays) - return arr_list + # x_chunks = [np.array(x).reshape(-1, 1) for x in x_chunks] + arrays = pool.map( + parallel_helper, + [ + (self.__class__.__name__, parameters, x_chunk) + for x_chunk in x_chunks + ], + ) + + arr = np.concatenate(arrays) + return arr class MACCSKeysFingerprintTransformer(FpsTransformer): + _DTYPE_FINGERPRINT = float + def __init__(self, parallel: Union[bool, int] = False): """MACCS keys fingerprinter calculates the 167 fixed MACCS keys """ - super().__init__(parallel = parallel) + super().__init__(parallel=parallel) self.nBits = 167 @property @@ -158,20 +177,33 @@ def nBits(self): @nBits.setter def nBits(self, nBits): if nBits != 167: - raise ValueError("nBits can only be 167, matching the number of defined MACCS keys!") + raise ValueError( + "nBits can only be 167, matching the number of defined MACCS keys!" + ) self._nBits = nBits def _mol2fp(self, mol): - return rdMolDescriptors.GetMACCSKeysFingerprint( - mol - ) + if mol: + return rdMolDescriptors.GetMACCSKeysFingerprint(mol) + else: + return False + class RDKitFingerprintTransformer(FpsTransformer): - def __init__(self, minPath:int = 1, maxPath:int =7, useHs:bool = True, branchedPaths:bool = True, - useBondOrder:bool = True, countSimulation:bool = False, countBounds = None, - fpSize:int = 2048, numBitsPerFeature:int = 2, atomInvariantsGenerator = None, - parallel: Union[bool, int] = False - ): + def __init__( + self, + minPath: int = 1, + maxPath: int = 7, + useHs: bool = True, + branchedPaths: bool = True, + useBondOrder: bool = True, + countSimulation: bool = False, + countBounds=None, + fpSize: int = 2048, + numBitsPerFeature: int = 2, + atomInvariantsGenerator=None, + parallel: Union[bool, int] = False, + ): """Calculates the RDKit fingerprints Parameters @@ -197,7 +229,7 @@ def __init__(self, minPath:int = 1, maxPath:int =7, useHs:bool = True, branchedP atomInvariantsGenerator : _type_, optional atom invariants to be used during fingerprint generation, by default None """ - super().__init__(parallel = parallel) + super().__init__(parallel=parallel) self.minPath = minPath self.maxPath = maxPath self.useHs = useHs @@ -213,27 +245,46 @@ def __init__(self, minPath:int = 1, maxPath:int =7, useHs:bool = True, branchedP def fpSize(self): return self.nBits - #Scikit-Learn expects to be able to set fpSize directly on object via .set_params(), so this updates nBits used by the abstract class + # Scikit-Learn expects to be able to set fpSize directly on object via .set_params(), so this updates nBits used by the abstract class @fpSize.setter def fpSize(self, fpSize): self.nBits = fpSize def _mol2fp(self, mol): - generator = rdFingerprintGenerator.GetRDKitFPGenerator(minPath=int(self.minPath), maxPath=int(self.maxPath), - useHs=bool(self.useHs), branchedPaths=bool(self.branchedPaths), - useBondOrder=bool(self.useBondOrder), - countSimulation=bool(self.countSimulation), - countBounds=bool(self.countBounds), fpSize=int(self.fpSize), - numBitsPerFeature=int(self.numBitsPerFeature), - atomInvariantsGenerator=self.atomInvariantsGenerator - ) + generator = rdFingerprintGenerator.GetRDKitFPGenerator( + minPath=int(self.minPath), + maxPath=int(self.maxPath), + useHs=bool(self.useHs), + branchedPaths=bool(self.branchedPaths), + useBondOrder=bool(self.useBondOrder), + countSimulation=bool(self.countSimulation), + countBounds=bool(self.countBounds), + fpSize=int(self.fpSize), + numBitsPerFeature=int(self.numBitsPerFeature), + atomInvariantsGenerator=self.atomInvariantsGenerator, + ) return generator.GetFingerprint(mol) -class AtomPairFingerprintTransformer(FpsTransformer): #FIXME, some of the init arguments seems to be molecule specific, and should probably not be setable? - def __init__(self, minLength:int = 1, maxLength:int = 30, fromAtoms = 0, ignoreAtoms = 0, atomInvariants = 0, - nBitsPerEntry:int = 4, includeChirality:bool = False, use2D:bool = True, confId:int = -1, nBits=2048, - useCounts:bool=False, parallel: Union[bool, int] = False,): - super().__init__(parallel = parallel) + +class AtomPairFingerprintTransformer( + FpsTransformer +): # FIXME, some of the init arguments seems to be molecule specific, and should probably not be setable? + def __init__( + self, + minLength: int = 1, + maxLength: int = 30, + fromAtoms=0, + ignoreAtoms=0, + atomInvariants=0, + nBitsPerEntry: int = 4, + includeChirality: bool = False, + use2D: bool = True, + confId: int = -1, + nBits=2048, + useCounts: bool = False, + parallel: Union[bool, int] = False, + ): + super().__init__(parallel=parallel) self.minLength = minLength self.maxLength = maxLength self.fromAtoms = fromAtoms @@ -248,34 +299,48 @@ def __init__(self, minLength:int = 1, maxLength:int = 30, fromAtoms = 0, ignoreA def _mol2fp(self, mol): if self.useCounts: - return rdMolDescriptors.GetHashedAtomPairFingerprint(mol, nBits=int(self.nBits), - minLength=int(self.minLength), - maxLength=int(self.maxLength), - fromAtoms=self.fromAtoms, - ignoreAtoms=self.ignoreAtoms, - atomInvariants=self.atomInvariants, - includeChirality=bool(self.includeChirality), - use2D=bool(self.use2D), - confId=int(self.confId) - ) + return rdMolDescriptors.GetHashedAtomPairFingerprint( + mol, + nBits=int(self.nBits), + minLength=int(self.minLength), + maxLength=int(self.maxLength), + fromAtoms=self.fromAtoms, + ignoreAtoms=self.ignoreAtoms, + atomInvariants=self.atomInvariants, + includeChirality=bool(self.includeChirality), + use2D=bool(self.use2D), + confId=int(self.confId), + ) else: - return rdMolDescriptors.GetHashedAtomPairFingerprintAsBitVect(mol, nBits=int(self.nBits), - minLength=int(self.minLength), - maxLength=int(self.maxLength), - fromAtoms=self.fromAtoms, - ignoreAtoms=self.ignoreAtoms, - atomInvariants=self.atomInvariants, - nBitsPerEntry=int(self.nBitsPerEntry), - includeChirality=bool(self.includeChirality), - use2D=bool(self.use2D), - confId=int(self.confId) - ) + return rdMolDescriptors.GetHashedAtomPairFingerprintAsBitVect( + mol, + nBits=int(self.nBits), + minLength=int(self.minLength), + maxLength=int(self.maxLength), + fromAtoms=self.fromAtoms, + ignoreAtoms=self.ignoreAtoms, + atomInvariants=self.atomInvariants, + nBitsPerEntry=int(self.nBitsPerEntry), + includeChirality=bool(self.includeChirality), + use2D=bool(self.use2D), + confId=int(self.confId), + ) + class TopologicalTorsionFingerprintTransformer(FpsTransformer): - def __init__(self, targetSize:int = 4, fromAtoms = 0, ignoreAtoms = 0, atomInvariants = 0, - includeChirality:bool = False, nBitsPerEntry:int = 4, nBits=2048, - useCounts:bool=False, parallel: Union[bool, int] = False): - super().__init__(parallel = parallel) + def __init__( + self, + targetSize: int = 4, + fromAtoms=0, + ignoreAtoms=0, + atomInvariants=0, + includeChirality: bool = False, + nBitsPerEntry: int = 4, + nBits=2048, + useCounts: bool = False, + parallel: Union[bool, int] = False, + ): + super().__init__(parallel=parallel) self.targetSize = targetSize self.fromAtoms = fromAtoms self.ignoreAtoms = ignoreAtoms @@ -287,27 +352,41 @@ def __init__(self, targetSize:int = 4, fromAtoms = 0, ignoreAtoms = 0, atomInvar def _mol2fp(self, mol): if self.useCounts: - return rdMolDescriptors.GetHashedTopologicalTorsionFingerprint(mol, nBits=int(self.nBits), - targetSize=int(self.targetSize), - fromAtoms=self.fromAtoms, - ignoreAtoms=self.ignoreAtoms, - atomInvariants=self.atomInvariants, - includeChirality=bool(self.includeChirality), - ) + return rdMolDescriptors.GetHashedTopologicalTorsionFingerprint( + mol, + nBits=int(self.nBits), + targetSize=int(self.targetSize), + fromAtoms=self.fromAtoms, + ignoreAtoms=self.ignoreAtoms, + atomInvariants=self.atomInvariants, + includeChirality=bool(self.includeChirality), + ) else: - return rdMolDescriptors.GetHashedTopologicalTorsionFingerprintAsBitVect(mol, nBits=int(self.nBits), - targetSize=int(self.targetSize), - fromAtoms=self.fromAtoms, - ignoreAtoms=self.ignoreAtoms, - atomInvariants=self.atomInvariants, - includeChirality=bool(self.includeChirality), - nBitsPerEntry=int(self.nBitsPerEntry) - ) + return rdMolDescriptors.GetHashedTopologicalTorsionFingerprintAsBitVect( + mol, + nBits=int(self.nBits), + targetSize=int(self.targetSize), + fromAtoms=self.fromAtoms, + ignoreAtoms=self.ignoreAtoms, + atomInvariants=self.atomInvariants, + includeChirality=bool(self.includeChirality), + nBitsPerEntry=int(self.nBitsPerEntry), + ) + class MHFingerprintTransformer(FpsTransformer): # https://jcheminf.biomedcentral.com/articles/10.1186/s13321-018-0321-8 - def __init__(self, radius:int=3, rings:bool=True, isomeric:bool=False, kekulize:bool=False, - min_radius:int=1, n_permutations:int=2048, seed:int=42, parallel: Union[bool, int] = False,): + def __init__( + self, + radius: int = 3, + rings: bool = True, + isomeric: bool = False, + kekulize: bool = False, + min_radius: int = 1, + n_permutations: int = 2048, + seed: int = 42, + parallel: Union[bool, int] = False, + ): """Transforms the RDKit mol into the MinHash fingerprint (MHFP) Args: @@ -316,17 +395,17 @@ def __init__(self, radius:int=3, rings:bool=True, isomeric:bool=False, kekulize: isomeric (bool, optional): Whether the isomeric SMILES to be considered. Defaults to False. kekulize (bool, optional): Whether or not to kekulize the extracted SMILES. Defaults to False. min_radius (int, optional): The minimum radius that is used to extract n-gram. Defaults to 1. - n_permutations (int, optional): The number of permutations used for hashing. Defaults to 0, + n_permutations (int, optional): The number of permutations used for hashing. Defaults to 0, this is effectively the length of the FP seed (int, optional): The value used to seed numpy.random. Defaults to 0. """ - super().__init__(parallel = parallel) + super().__init__(parallel=parallel) self.radius = radius self.rings = rings self.isomeric = isomeric self.kekulize = kekulize self.min_radius = min_radius - #Set the .n_permutations and .seed without creating the encoder twice + # Set the .n_permutations and .seed without creating the encoder twice self._n_permutations = n_permutations self._seed = seed # create the encoder instance @@ -336,7 +415,7 @@ def __getstate__(self): # Get the state of the parent class state = super().__getstate__() # Remove the unpicklable property from the state - state.pop("mhfp_encoder", None) # mhfp_encoder is not picklable + state.pop("mhfp_encoder", None) # mhfp_encoder is not picklable return state def __setstate__(self, state): @@ -348,14 +427,18 @@ def __setstate__(self, state): _DTYPE_FINGERPRINT = np.int32 def _mol2fp(self, mol): - fp = self.mhfp_encoder.EncodeMol(mol, self.radius, self.rings, self.isomeric, self.kekulize, self.min_radius) + fp = self.mhfp_encoder.EncodeMol( + mol, self.radius, self.rings, self.isomeric, self.kekulize, self.min_radius + ) return fp - + def _fp2array(self, fp): return np.array(fp) def _recreate_encoder(self): - self.mhfp_encoder = rdMHFPFingerprint.MHFPEncoder(self._n_permutations, self._seed) + self.mhfp_encoder = rdMHFPFingerprint.MHFPEncoder( + self._n_permutations, self._seed + ) @property def seed(self): @@ -382,10 +465,21 @@ def nBits(self): # to be compliant with the requirement of the base class return self._n_permutations + class SECFingerprintTransformer(FpsTransformer): # https://jcheminf.biomedcentral.com/articles/10.1186/s13321-018-0321-8 - def __init__(self, radius:int=3, rings:bool=True, isomeric:bool=False, kekulize:bool=False, - min_radius:int=1, length:int=2048, n_permutations:int=0, seed:int=0, parallel: Union[bool, int] = False,): + def __init__( + self, + radius: int = 3, + rings: bool = True, + isomeric: bool = False, + kekulize: bool = False, + min_radius: int = 1, + length: int = 2048, + n_permutations: int = 0, + seed: int = 0, + parallel: Union[bool, int] = False, + ): """Transforms the RDKit mol into the SMILES extended connectivity fingerprint (SECFP) Args: @@ -398,14 +492,14 @@ def __init__(self, radius:int=3, rings:bool=True, isomeric:bool=False, kekulize: n_permutations (int, optional): The number of permutations used for hashing. Defaults to 0. seed (int, optional): The value used to seed numpy.random. Defaults to 0. """ - super().__init__(parallel = parallel) + super().__init__(parallel=parallel) self.radius = radius self.rings = rings self.isomeric = isomeric self.kekulize = kekulize self.min_radius = min_radius self.length = length - #Set the .n_permutations and seed without creating the encoder twice + # Set the .n_permutations and seed without creating the encoder twice self._n_permutations = n_permutations self._seed = seed # create the encoder instance @@ -415,7 +509,7 @@ def __getstate__(self): # Get the state of the parent class state = super().__getstate__() # Remove the unpicklable property from the state - state.pop("mhfp_encoder", None) # mhfp_encoder is not picklable + state.pop("mhfp_encoder", None) # mhfp_encoder is not picklable return state def __setstate__(self, state): @@ -425,10 +519,20 @@ def __setstate__(self, state): self._recreate_encoder() def _mol2fp(self, mol): - return self.mhfp_encoder.EncodeSECFPMol(mol, self.radius, self.rings, self.isomeric, self.kekulize, self.min_radius, self.length) + return self.mhfp_encoder.EncodeSECFPMol( + mol, + self.radius, + self.rings, + self.isomeric, + self.kekulize, + self.min_radius, + self.length, + ) def _recreate_encoder(self): - self.mhfp_encoder = rdMHFPFingerprint.MHFPEncoder(self._n_permutations, self._seed) + self.mhfp_encoder = rdMHFPFingerprint.MHFPEncoder( + self._n_permutations, self._seed + ) @property def seed(self): @@ -455,8 +559,18 @@ def nBits(self): # to be compliant with the requirement of the base class return self.length + class MorganFingerprintTransformer(FpsTransformer): - def __init__(self, nBits=2048, radius=2, useChirality=False, useBondTypes=True, useFeatures=False, useCounts=False, parallel: Union[bool, int] = False,): + def __init__( + self, + nBits=2048, + radius=2, + useChirality=False, + useBondTypes=True, + useFeatures=False, + useCounts=False, + parallel: Union[bool, int] = False, + ): """Transform RDKit mols into Count or bit-based hashed MorganFingerprints Parameters @@ -474,7 +588,7 @@ def __init__(self, nBits=2048, radius=2, useChirality=False, useBondTypes=True, useCounts : bool, optional If toggled will create the count and not bit-based fingerprint, by default False """ - super().__init__(parallel = parallel) + super().__init__(parallel=parallel) self.nBits = nBits self.radius = radius self.useChirality = useChirality @@ -485,19 +599,36 @@ def __init__(self, nBits=2048, radius=2, useChirality=False, useBondTypes=True, def _mol2fp(self, mol): if self.useCounts: return rdMolDescriptors.GetHashedMorganFingerprint( - mol,int(self.radius),nBits=int(self.nBits), useFeatures=bool(self.useFeatures), - useChirality=bool(self.useChirality), useBondTypes=bool(self.useBondTypes) + mol, + int(self.radius), + nBits=int(self.nBits), + useFeatures=bool(self.useFeatures), + useChirality=bool(self.useChirality), + useBondTypes=bool(self.useBondTypes), ) else: return rdMolDescriptors.GetMorganFingerprintAsBitVect( - mol,int(self.radius),nBits=int(self.nBits), useFeatures=bool(self.useFeatures), - useChirality=bool(self.useChirality), useBondTypes=bool(self.useBondTypes) + mol, + int(self.radius), + nBits=int(self.nBits), + useFeatures=bool(self.useFeatures), + useChirality=bool(self.useChirality), + useBondTypes=bool(self.useBondTypes), ) - + + class AvalonFingerprintTransformer(FpsTransformer): # Fingerprint from the Avalon toolkeit, https://doi.org/10.1021/ci050413p - def __init__(self, nBits:int = 512, isQuery:bool = False, resetVect:bool = False, bitFlags:int = 15761407, useCounts:bool = False, parallel: Union[bool, int] = False,): - """ Transform RDKit mols into Count or bit-based Avalon Fingerprints + def __init__( + self, + nBits: int = 512, + isQuery: bool = False, + resetVect: bool = False, + bitFlags: int = 15761407, + useCounts: bool = False, + parallel: Union[bool, int] = False, + ): + """Transform RDKit mols into Count or bit-based Avalon Fingerprints Parameters ---------- @@ -512,26 +643,28 @@ def __init__(self, nBits:int = 512, isQuery:bool = False, resetVect:bool = False useCounts : bool, optional If toggled will create the count and not bit-based fingerprint, by default False """ - super().__init__(parallel = parallel) + super().__init__(parallel=parallel) self.nBits = nBits self.isQuery = isQuery self.resetVect = resetVect self.bitFlags = bitFlags self.useCounts = useCounts - + def _mol2fp(self, mol): if self.useCounts: - return pyAvalonTools.GetAvalonCountFP(mol, - nBits=int(self.nBits), - isQuery=bool(self.isQuery), - bitFlags=int(self.bitFlags) + return pyAvalonTools.GetAvalonCountFP( + mol, + nBits=int(self.nBits), + isQuery=bool(self.isQuery), + bitFlags=int(self.bitFlags), ) else: - return pyAvalonTools.GetAvalonFP(mol, - nBits=int(self.nBits), - isQuery=bool(self.isQuery), - resetVect=bool(self.resetVect), - bitFlags=int(self.bitFlags) + return pyAvalonTools.GetAvalonFP( + mol, + nBits=int(self.nBits), + isQuery=bool(self.isQuery), + resetVect=bool(self.resetVect), + bitFlags=int(self.bitFlags), ) @@ -541,6 +674,6 @@ def parallel_helper(args): Intention is to be able to do this in chilcprocesses as some classes can't be pickled""" classname, parameters, X_mols = args from scikit_mol import fingerprints + transformer = getattr(fingerprints, classname)(**parameters) return transformer._transform(X_mols) - diff --git a/scikit_mol/wrapper.py b/scikit_mol/wrapper.py index 1899e27..bd2bfe5 100644 --- a/scikit_mol/wrapper.py +++ b/scikit_mol/wrapper.py @@ -8,7 +8,11 @@ from sklearn.pipeline import Pipeline from sklearn.utils.metaestimators import available_if -from scikit_mol._invalid import rdkit_error_handling, InvalidInstance, NumpyArrayWithInvalidInstances +from scikit_mol._invalid import ( + rdkit_error_handling, + InvalidMol, + NumpyArrayWithInvalidInstances, +) class AbstractWrapper(BaseEstimator, ABC): @@ -20,6 +24,7 @@ class AbstractWrapper(BaseEstimator, ABC): model: BaseEstimator | Pipeline The wrapped model or pipeline. """ + model: BaseEstimator | Pipeline def __init__(self, replace_invalid: bool, replace_value: Any = np.nan): @@ -49,7 +54,9 @@ def has_fit_predict(self) -> bool: class WrappedTransformer(AbstractWrapper): """Wrapper for sklearn transformer objects.""" - def __init__(self, model: BaseEstimator, replace_invalid: bool = False, replace_value=np.nan): + def __init__( + self, model: BaseEstimator, replace_invalid: bool = False, replace_value=np.nan + ): """Initialize the WrappedTransformer. Parameters @@ -81,7 +88,7 @@ def _fit_transform(self, X, y): @available_if(has_fit_transform) def fit_transform(self, X, y=None): - out = self._fit_transform(X,y) + out = self._fit_transform(X, y) if not self.replace_invalid: return out @@ -89,7 +96,4 @@ def fit_transform(self, X, y=None): return out.array_filled_with(self.replace_value) if isinstance(out, list): - return [self.replace_value if isinstance(v, InvalidInstance) else v for v in out] - - - + return [self.replace_value if isinstance(v, InvalidMol) else v for v in out] diff --git a/tests/test_invalid_helpers/invalid_transformer.py b/tests/test_invalid_helpers/invalid_transformer.py index ffe2738..fb0add7 100644 --- a/tests/test_invalid_helpers/invalid_transformer.py +++ b/tests/test_invalid_helpers/invalid_transformer.py @@ -3,7 +3,7 @@ from rdkit import Chem from scikit_mol._invalid import ( - InvalidInstance, + InvalidMol, rdkit_error_handling, ) @@ -26,15 +26,15 @@ def __init__(self, atomic_number_set: Sequence[int] | None = None) -> None: atomic_number_set = {16} self.atomic_number_set = set(atomic_number_set) - def _transform_mol(self, mol: Chem.Mol) -> Chem.Mol | InvalidInstance: + def _transform_mol(self, mol: Chem.Mol) -> Chem.Mol | InvalidMol: unique_elements = {atom.GetAtomicNum() for atom in mol.GetAtoms()} forbidden_elements = self.atomic_number_set & unique_elements if forbidden_elements: - return InvalidInstance(str(self), f"Molecule contains {forbidden_elements}") + return InvalidMol(str(self), f"Molecule contains {forbidden_elements}") return mol @rdkit_error_handling - def transform(self, X: list[Chem.Mol]) -> list[Chem.Mol | InvalidInstance]: + def transform(self, X: list[Chem.Mol]) -> list[Chem.Mol | InvalidMol]: return [self._transform_mol(mol) for mol in X] def fit(self, X, y, fit_params): From 79f5dab720010ae77abab1efa3103a2c22f2dc32 Mon Sep 17 00:00:00 2001 From: Esben Jannik Bjerrum Date: Fri, 27 Sep 2024 10:25:45 +0200 Subject: [PATCH 15/41] Added the simplified NanGuardWrapper. Some tests fails, especially around some of the tests with pandas output. Need more analysis --- scikit_mol/conversions.py | 2 +- scikit_mol/wrapper.py | 110 +++++++++++++++++- ...nvalid_handling.py => invalid_handling.py} | 4 +- 3 files changed, 111 insertions(+), 5 deletions(-) rename tests/{test_invalid_handling.py => invalid_handling.py} (99%) diff --git a/scikit_mol/conversions.py b/scikit_mol/conversions.py index 00afddb..34c4dad 100644 --- a/scikit_mol/conversions.py +++ b/scikit_mol/conversions.py @@ -83,7 +83,7 @@ def _transform(self, X): else: message = f"Invalid SMILES: {smiles}" X_out.append(InvalidMol(str(self), message)) - return X_out + return np.array(X_out).reshape(-1, 1) @check_transform_input def inverse_transform( diff --git a/scikit_mol/wrapper.py b/scikit_mol/wrapper.py index bd2bfe5..6cf6817 100644 --- a/scikit_mol/wrapper.py +++ b/scikit_mol/wrapper.py @@ -4,13 +4,15 @@ from typing import Any import numpy as np +from functools import wraps +import warnings from sklearn.base import BaseEstimator from sklearn.pipeline import Pipeline from sklearn.utils.metaestimators import available_if from scikit_mol._invalid import ( rdkit_error_handling, - InvalidMol, + # InvalidMol, NumpyArrayWithInvalidInstances, ) @@ -97,3 +99,109 @@ def fit_transform(self, X, y=None): if isinstance(out, list): return [self.replace_value if isinstance(v, InvalidMol) else v for v in out] + + +def filter_invalid_rows(fill_value=np.nan, warn_on_invalid=False): + def decorator(func): + @wraps(func) + def wrapper(obj, X, *args, **kwargs): + valid_mask = np.isfinite(X).all(axis=1) + + if warn_on_invalid and not np.all(valid_mask): + warnings.warn( + f"Invalid data detected in {func.__name__}. This may lead to unexpected results.", + UserWarning, + ) + + valid_indices = np.where(valid_mask)[0] + reduced_X = X[valid_mask] + + result = func(obj, reduced_X, *args, **kwargs) + + if result is None: # For methods like fit that return None + return None + + if isinstance(result, np.ndarray): + output = np.full((X.shape[0], result.shape[1]), fill_value) + output[valid_indices] = result + return output + else: + return result # For methods that return non-array results + + return wrapper + + return decorator + + +class NanGuardWrapper(BaseEstimator): + """Nan/Inf safe wrapper for sklearn estimator objects.""" + + def __init__( + self, + estimator: BaseEstimator, + replace_invalid: bool = True, + replace_value=np.nan, + ): + super().__init__() + self.replace_invalid = replace_invalid + self.replace_value = replace_value + self.estimator = estimator + + def has_predict(self) -> bool: + return hasattr(self.estimator, "predict") + + def has_predict_proba(self) -> bool: + return hasattr(self.estimator, "predict_proba") + + def has_transform(self) -> bool: + return hasattr(self.estimator, "transform") + + def has_fit_transform(self) -> bool: + return hasattr(self.estimator, "fit_transform") + + def has_score(self) -> bool: + return hasattr(self.estimator, "score") + + def has_n_features_in_(self) -> bool: + return hasattr(self.estimator, "n_features_in_") + + def has_decision_function(self) -> bool: + return hasattr(self.estimator, "decision_function") + + @property + def n_features_in_(self) -> int: + return self.estimator.n_features_in_ + + @filter_invalid_rows(warn_on_invalid=True) + def fit(self, X, *args, **fit_params) -> Any: + return self.estimator.fit(X, *args, **fit_params) + + @available_if(has_predict) + @filter_invalid_rows() + def predict(self, X): + return self.estimator.predict(X) + + @available_if(has_decision_function) + @filter_invalid_rows() + def decision_function(self, X): + return self.estimator.decision_function(X) + + @available_if(has_predict_proba) + @filter_invalid_rows() + def predict_proba(self, X): + return self.estimator.predict_proba(X) + + @available_if(has_transform) + @filter_invalid_rows() + def transform(self, X): + return self.estimator.transform(X) + + @available_if(has_fit_transform) + @filter_invalid_rows(warn_on_invalid=True) + def fit_transform(self, X, y): + return self.estimator.fit_transform(X, y) + + @available_if(has_score) + @filter_invalid_rows() + def score(self, X, y): + return self.estimator.score(X, y) diff --git a/tests/test_invalid_handling.py b/tests/invalid_handling.py similarity index 99% rename from tests/test_invalid_handling.py rename to tests/invalid_handling.py index b848659..884154b 100644 --- a/tests/test_invalid_handling.py +++ b/tests/invalid_handling.py @@ -18,15 +18,13 @@ def smilestofp_pipeline(): ("smiles_to_mol", SmilesToMolTransformer()), ("remove_sulfur", TestInvalidTransformer()), ("mol_2_fp", MorganFingerprintTransformer()), - ("PCA", WrappedTransformer(PCA(2), replace_invalid=True)) + ("PCA", WrappedTransformer(PCA(2), replace_invalid=True)), ] - ) return pipeline def test_descriptor_transformer(smiles_list, invalid_smiles_list, smilestofp_pipeline): - smilestofp_pipeline.set_params() mol_pca = smilestofp_pipeline.fit_transform(smiles_list) error_mol_pca = smilestofp_pipeline.fit_transform(invalid_smiles_list) From 7bbd8f10794fdcf88794025cec94cf77549af237 Mon Sep 17 00:00:00 2001 From: Esben Jannik Bjerrum Date: Fri, 27 Sep 2024 10:46:33 +0200 Subject: [PATCH 16/41] Added support for pandas output. However, the set_output(), does not work on the wrapper class?? but it works if set on e.g. a pca instance. --- scikit_mol/wrapper.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/scikit_mol/wrapper.py b/scikit_mol/wrapper.py index 6cf6817..501c1d9 100644 --- a/scikit_mol/wrapper.py +++ b/scikit_mol/wrapper.py @@ -4,11 +4,13 @@ from typing import Any import numpy as np +import pandas as pd from functools import wraps import warnings from sklearn.base import BaseEstimator from sklearn.pipeline import Pipeline from sklearn.utils.metaestimators import available_if +from sklearn.base import TransformerMixin from scikit_mol._invalid import ( rdkit_error_handling, @@ -125,6 +127,12 @@ def wrapper(obj, X, *args, **kwargs): output = np.full((X.shape[0], result.shape[1]), fill_value) output[valid_indices] = result return output + elif isinstance(result, pd.DataFrame): + # Create a DataFrame with NaN values for all rows + output = pd.DataFrame(index=range(X.shape[0]), columns=result.columns) + # Fill the valid rows with the result data + output.iloc[valid_indices] = result + return output else: return result # For methods that return non-array results @@ -133,7 +141,7 @@ def wrapper(obj, X, *args, **kwargs): return decorator -class NanGuardWrapper(BaseEstimator): +class NanGuardWrapper(BaseEstimator, TransformerMixin): """Nan/Inf safe wrapper for sklearn estimator objects.""" def __init__( From 716c898e42d82276adf6d42a36d3894f7a901df8 Mon Sep 17 00:00:00 2001 From: Esben Jannik Bjerrum Date: Fri, 27 Sep 2024 12:08:32 +0200 Subject: [PATCH 17/41] implemented the handle_errors flag on smilestomol transformer --- scikit_mol/conversions.py | 24 ++++++++++++++++++------ 1 file changed, 18 insertions(+), 6 deletions(-) diff --git a/scikit_mol/conversions.py b/scikit_mol/conversions.py index 34c4dad..c18d1bb 100644 --- a/scikit_mol/conversions.py +++ b/scikit_mol/conversions.py @@ -17,9 +17,10 @@ class SmilesToMolTransformer(BaseEstimator, TransformerMixin): - def __init__(self, parallel: Union[bool, int] = False): + def __init__(self, parallel: Union[bool, int] = False, handle_errors: bool = False): self.parallel = parallel self.start_method = None # TODO implement handling of start_method + self.handle_errors = handle_errors @feature_names_default_mol def get_feature_names_out(self, input_features=None): @@ -83,22 +84,33 @@ def _transform(self, X): else: message = f"Invalid SMILES: {smiles}" X_out.append(InvalidMol(str(self), message)) + if not self.handle_errors and not all(X_out): + fails = [x for x in X_out if not x] + raise ValueError( + f"Invalid SMILES found: {fails}." + ) # TODO with this appraoch we get all errors, but we do process ALL the smiles first which could be slow return np.array(X_out).reshape(-1, 1) @check_transform_input - def inverse_transform( - self, X_mols_list, y=None - ): # TODO, maybe the inverse transform should be configurable e.g. isomericSmiles etc.? + def inverse_transform(self, X_mols_list, y=None): X_out = [] for mol in X_mols_list: - if mol: + if isinstance(mol, Chem.Mol): try: smiles = Chem.MolToSmiles(mol) X_out.append(smiles) except Exception as e: - X_out.append(InvalidMol(str(self), str(e))) + X_out.append( + InvalidMol( + str(self), f"Error converting Mol to SMILES: {str(e)}" + ) + ) else: X_out.append(InvalidMol(str(self), f"Not a Mol: {mol}")) + if not self.handle_errors and not all(isinstance(x, str) for x in X_out): + fails = [x for x in X_out if not isinstance(x, str)] + raise ValueError(f"Invalid Mols found: {fails}.") + return np.array(X_out).reshape(-1, 1) From 0077ae607e4029384fa79bee4b1fa671324be07d Mon Sep 17 00:00:00 2001 From: Esben Jannik Bjerrum Date: Fri, 27 Sep 2024 13:40:01 +0200 Subject: [PATCH 18/41] Fixed an error in the tests of the sanitizer --- tests/fixtures.py | 112 +++++++++++++++++++++++++++------------- tests/test_sanitizer.py | 41 +++++++++++---- 2 files changed, 107 insertions(+), 46 deletions(-) diff --git a/tests/fixtures.py b/tests/fixtures.py index 2852068..11636d3 100644 --- a/tests/fixtures.py +++ b/tests/fixtures.py @@ -10,29 +10,40 @@ from sklearn.preprocessing import FunctionTransformer from sklearn.pipeline import make_pipeline from sklearn.compose import make_column_selector, make_column_transformer -from scikit_mol.fingerprints import MACCSKeysFingerprintTransformer, RDKitFingerprintTransformer, AtomPairFingerprintTransformer, \ - TopologicalTorsionFingerprintTransformer, MorganFingerprintTransformer, SECFingerprintTransformer, \ - MHFingerprintTransformer, AvalonFingerprintTransformer +from scikit_mol.fingerprints import ( + MACCSKeysFingerprintTransformer, + RDKitFingerprintTransformer, + AtomPairFingerprintTransformer, + TopologicalTorsionFingerprintTransformer, + MorganFingerprintTransformer, + SECFingerprintTransformer, + MHFingerprintTransformer, + AvalonFingerprintTransformer, +) from scikit_mol.descriptors import MolecularDescriptorTransformer from scikit_mol.conversions import SmilesToMolTransformer from scikit_mol.standardizer import Standardizer from scikit_mol.core import SKLEARN_VERSION_PANDAS_OUT, DEFAULT_MOL_COLUMN_NAME -#TODO these should really go into the conftest.py, so that they are automatically imported in the tests +# TODO these should really go into the conftest.py, so that they are automatically imported in the tests _SMILES_LIST = [ - 'O=C(O)c1ccccc1', - 'O=C([O-])c1ccccc1', - 'O=C([O-])c1ccccc1.[Na+]', - 'O=C(O[Na])c1ccccc1', - 'C[N+](C)C.O=C([O-])c1ccccc1', + "O=C(O)c1ccccc1", + "O=C([O-])c1ccccc1", + "O=C([O-])c1ccccc1.[Na+]", + "O=C(O[Na])c1ccccc1", + "C[N+](C)C.O=C([O-])c1ccccc1", +] +_CANONICAL_SMILES_LIST = [ + Chem.MolToSmiles(Chem.MolFromSmiles(smiles)) for smiles in _SMILES_LIST ] -_CANONICAL_SMILES_LIST = [Chem.MolToSmiles(Chem.MolFromSmiles(smiles)) for smiles in _SMILES_LIST] + @pytest.fixture def smiles_list(): return _CANONICAL_SMILES_LIST.copy() + _CONTAINER_CREATORS = [ lambda x: x, lambda x: np.array(x), @@ -48,36 +59,51 @@ def smiles_list(): ] for name in _names_to_test: _CONTAINER_CREATORS.append(lambda x, name=name: pd.Series(x, name=name)) - _CONTAINER_CREATORS.append(lambda x, name=name: pd.DataFrame({name: x}) if name else pd.DataFrame(x)) + _CONTAINER_CREATORS.append( + lambda x, name=name: pd.DataFrame({name: x}) if name else pd.DataFrame(x) + ) + -@pytest.fixture(params=[container(_CANONICAL_SMILES_LIST) for container in _CONTAINER_CREATORS] +@pytest.fixture( + params=[container(_CANONICAL_SMILES_LIST) for container in _CONTAINER_CREATORS] ) -def smiles_container(request, ): +def smiles_container( + request, +): return request.param.copy() + @pytest.fixture -def chiral_smiles_list(): #Need to be a certain size, so the fingerprints reacts to different max_lenǵths and radii - return [Chem.MolToSmiles(Chem.MolFromSmiles(smiles)) for smiles in [ - 'N[C@@H](C)C(=O)OCCCCCCCCCCCC', - 'C1C[C@H]2CCCC[C@H]2CC1CCCCCCCCC', - 'N[C@@H](C)C(=O)Oc1ccccc1CCCCCCCCCCCCCCCCCCN[H]']] +def chiral_smiles_list(): # Need to be a certain size, so the fingerprints reacts to different max_lenǵths and radii + return [ + Chem.MolToSmiles(Chem.MolFromSmiles(smiles)) + for smiles in [ + "N[C@@H](C)C(=O)OCCCCCCCCCCCC", + "C1C[C@H]2CCCC[C@H]2CC1CCCCCCCCC", + "N[C@@H](C)C(=O)Oc1ccccc1CCCCCCCCCCCCCCCCCCN[H]", + ] + ] + @pytest.fixture -def invalid_smiles_list(): - smiles_list = ['S-CC', 'Invalid'] - smiles_list.extend(_SMILES_LIST) +def invalid_smiles_list(smiles_list): + smiles_list.append("Invalid") return smiles_list + _MOLS_LIST = [Chem.MolFromSmiles(smiles) for smiles in _SMILES_LIST] + @pytest.fixture def mols_list(): return _MOLS_LIST.copy() + @pytest.fixture(params=[container(_MOLS_LIST) for container in _CONTAINER_CREATORS]) def mols_container(request): return request.param.copy() + @pytest.fixture def chiral_mols_list(chiral_smiles_list): return [Chem.MolFromSmiles(smiles) for smiles in chiral_smiles_list] @@ -85,41 +111,57 @@ def chiral_mols_list(chiral_smiles_list): @pytest.fixture def fingerprint(mols_list): - return rdMolDescriptors.GetHashedMorganFingerprint(mols_list[0],2,nBits=1000) + return rdMolDescriptors.GetHashedMorganFingerprint(mols_list[0], 2, nBits=1000) + _DIR_DATA = Path(__file__).parent / "data" _FILE_SLC6A4 = _DIR_DATA / "SLC6A4_active_excapedb_subset.csv" _FILE_SLC6A4_WITH_CDDD = _DIR_DATA / "CDDD_SLC6A4_active_excapedb_subset.csv.gz" + @pytest.fixture def SLC6A4_subset(): data = pd.read_csv(_FILE_SLC6A4) return data + @pytest.fixture def SLC6A4_subset_with_cddd(SLC6A4_subset): data = SLC6A4_subset.copy().drop_duplicates(subset="Ambit_InchiKey") cddd = pd.read_csv(_FILE_SLC6A4_WITH_CDDD, index_col="Ambit_InchiKey") - data = data.merge(cddd, left_on="Ambit_InchiKey", right_index=True, how="inner", validate="one_to_one") + data = data.merge( + cddd, + left_on="Ambit_InchiKey", + right_index=True, + how="inner", + validate="one_to_one", + ) return data -skip_pandas_output_test = pytest.mark.skipif(Version(sklearn.__version__) < SKLEARN_VERSION_PANDAS_OUT, reason=f"requires scikit-learn {SKLEARN_VERSION_PANDAS_OUT} or higher") + +skip_pandas_output_test = pytest.mark.skipif( + Version(sklearn.__version__) < SKLEARN_VERSION_PANDAS_OUT, + reason=f"requires scikit-learn {SKLEARN_VERSION_PANDAS_OUT} or higher", +) _FEATURIZER_CLASSES = [ - MACCSKeysFingerprintTransformer, - RDKitFingerprintTransformer, - AtomPairFingerprintTransformer, - TopologicalTorsionFingerprintTransformer, - MorganFingerprintTransformer, - SECFingerprintTransformer, - MHFingerprintTransformer, - AvalonFingerprintTransformer, - MolecularDescriptorTransformer, - ] + MACCSKeysFingerprintTransformer, + RDKitFingerprintTransformer, + AtomPairFingerprintTransformer, + TopologicalTorsionFingerprintTransformer, + MorganFingerprintTransformer, + SECFingerprintTransformer, + MHFingerprintTransformer, + AvalonFingerprintTransformer, + MolecularDescriptorTransformer, +] + + @pytest.fixture(params=_FEATURIZER_CLASSES) def featurizer(request): return request.param() + @pytest.fixture def combined_transformer(featurizer): descriptors_pipeline = make_pipeline( @@ -137,4 +179,4 @@ def combined_transformer(featurizer): (identity_pipeline, make_column_selector(pattern=r"^cddd_\d+$")), remainder="drop", ) - return transformer \ No newline at end of file + return transformer diff --git a/tests/test_sanitizer.py b/tests/test_sanitizer.py index def82a0..a69d597 100644 --- a/tests/test_sanitizer.py +++ b/tests/test_sanitizer.py @@ -5,43 +5,62 @@ from fixtures import smiles_list, invalid_smiles_list from scikit_mol.utilities import CheckSmilesSanitazion + @pytest.fixture def sanitizer(): return CheckSmilesSanitazion() + @pytest.fixture def return_mol_sanitizer(): return CheckSmilesSanitazion(return_mol=True) + def test_checksmilessanitation(smiles_list, invalid_smiles_list, sanitizer): smiles_list_sanitized, errors = sanitizer.sanitize(invalid_smiles_list) assert len(invalid_smiles_list) > len(smiles_list_sanitized) - assert all([ a == b for a, b in zip(smiles_list, smiles_list_sanitized)]) + assert all([a == b for a, b in zip(smiles_list, smiles_list_sanitized)]) assert errors[0] == sanitizer.errors.SMILES[0] + def test_checksmilessanitation_x_and_y(smiles_list, invalid_smiles_list, sanitizer): - smiles_list_sanitized, y_sanitized, errors, y_errors = sanitizer.sanitize(smiles_list, list(range(len(smiles_list)))) + smiles_list_sanitized, y_sanitized, errors, y_errors = sanitizer.sanitize( + invalid_smiles_list, list(range(len(invalid_smiles_list))) + ) assert len(invalid_smiles_list) > len(smiles_list_sanitized) - assert all([ a == b for a, b in zip(smiles_list, smiles_list_sanitized)]) + assert all([a == b for a, b in zip(smiles_list, smiles_list_sanitized)]) assert errors[0] == sanitizer.errors.SMILES[0] - #Test that y is correctly split into y_error and the rest - assert all([ a == b for a, b in zip(y_sanitized, list(range(len(smiles_list) -1 )))]) - assert y_errors[0] == len(smiles_list)-1 #Last smiles is invalid + # Test that y is correctly split into y_error and the rest + assert all([a == b for a, b in zip(y_sanitized, list(range(len(smiles_list) - 1)))]) + assert y_errors[0] == len(smiles_list) - 1 # Last smiles is invalid + def test_checksmilessanitation_np(smiles_list, invalid_smiles_list, sanitizer): smiles_list_sanitized, errors = sanitizer.sanitize(np.array(invalid_smiles_list)) assert len(invalid_smiles_list) > len(smiles_list_sanitized) - assert all([ a == b for a, b in zip(smiles_list, smiles_list_sanitized)]) + assert all([a == b for a, b in zip(smiles_list, smiles_list_sanitized)]) assert errors[0] == sanitizer.errors.SMILES[0] + def test_checksmilessanitation_numpy(smiles_list, invalid_smiles_list, sanitizer): smiles_list_sanitized, errors = sanitizer.sanitize(pd.Series(invalid_smiles_list)) assert len(invalid_smiles_list) > len(smiles_list_sanitized) - assert all([ a == b for a, b in zip(smiles_list, smiles_list_sanitized)]) + assert all([a == b for a, b in zip(smiles_list, smiles_list_sanitized)]) assert errors[0] == sanitizer.errors.SMILES[0] -def test_checksmilessanitation_return_mol(smiles_list, invalid_smiles_list, return_mol_sanitizer): + +def test_checksmilessanitation_return_mol( + smiles_list, invalid_smiles_list, return_mol_sanitizer +): smiles_list_sanitized, errors = return_mol_sanitizer.sanitize(invalid_smiles_list) assert len(invalid_smiles_list) > len(smiles_list_sanitized) - assert all([ a == b for a, b in zip(smiles_list, [Chem.MolToSmiles(smiles) for smiles in smiles_list_sanitized])]) - assert errors[0] == return_mol_sanitizer.errors.SMILES[0] \ No newline at end of file + assert all( + [ + a == b + for a, b in zip( + smiles_list, + [Chem.MolToSmiles(smiles) for smiles in smiles_list_sanitized], + ) + ] + ) + assert errors[0] == return_mol_sanitizer.errors.SMILES[0] From 3a195e67809c1a6ed4260f58385390f53315ec70 Mon Sep 17 00:00:00 2001 From: Esben Jannik Bjerrum Date: Fri, 27 Sep 2024 14:35:11 +0200 Subject: [PATCH 19/41] cleanup of some accidentially added hidden files --- .idea/.gitignore | 3 --- .idea/inspectionProfiles/Project_Default.xml | 6 ------ .idea/inspectionProfiles/profiles_settings.xml | 6 ------ .idea/misc.xml | 10 ---------- .idea/scikit-mol.iml | 13 ------------- .idea/vcs.xml | 6 ------ 6 files changed, 44 deletions(-) delete mode 100644 .idea/.gitignore delete mode 100644 .idea/inspectionProfiles/Project_Default.xml delete mode 100644 .idea/inspectionProfiles/profiles_settings.xml delete mode 100644 .idea/misc.xml delete mode 100644 .idea/scikit-mol.iml delete mode 100644 .idea/vcs.xml diff --git a/.idea/.gitignore b/.idea/.gitignore deleted file mode 100644 index 26d3352..0000000 --- a/.idea/.gitignore +++ /dev/null @@ -1,3 +0,0 @@ -# Default ignored files -/shelf/ -/workspace.xml diff --git a/.idea/inspectionProfiles/Project_Default.xml b/.idea/inspectionProfiles/Project_Default.xml deleted file mode 100644 index 9aa2337..0000000 --- a/.idea/inspectionProfiles/Project_Default.xml +++ /dev/null @@ -1,6 +0,0 @@ - - - - \ No newline at end of file diff --git a/.idea/inspectionProfiles/profiles_settings.xml b/.idea/inspectionProfiles/profiles_settings.xml deleted file mode 100644 index 105ce2d..0000000 --- a/.idea/inspectionProfiles/profiles_settings.xml +++ /dev/null @@ -1,6 +0,0 @@ - - - - \ No newline at end of file diff --git a/.idea/misc.xml b/.idea/misc.xml deleted file mode 100644 index 13c8595..0000000 --- a/.idea/misc.xml +++ /dev/null @@ -1,10 +0,0 @@ - - - - - - - - \ No newline at end of file diff --git a/.idea/scikit-mol.iml b/.idea/scikit-mol.iml deleted file mode 100644 index fb6d745..0000000 --- a/.idea/scikit-mol.iml +++ /dev/null @@ -1,13 +0,0 @@ - - - - - - - - - \ No newline at end of file diff --git a/.idea/vcs.xml b/.idea/vcs.xml deleted file mode 100644 index 35eb1dd..0000000 --- a/.idea/vcs.xml +++ /dev/null @@ -1,6 +0,0 @@ - - - - - - \ No newline at end of file From cf98fd85bc54879eef28d7353da86aae02f12ac0 Mon Sep 17 00:00:00 2001 From: Esben Jannik Bjerrum Date: Fri, 27 Sep 2024 15:11:59 +0200 Subject: [PATCH 20/41] Added a basic test of an error_handling pipeline --- scikit_mol/fingerprints.py | 1 - scikit_mol/wrapper.py | 8 ++++++-- tests/fixtures.py | 1 + tests/{invalid_handling.py => test_invalid_handling.py} | 0 4 files changed, 7 insertions(+), 3 deletions(-) rename tests/{invalid_handling.py => test_invalid_handling.py} (100%) diff --git a/scikit_mol/fingerprints.py b/scikit_mol/fingerprints.py index 80a3123..146ede2 100644 --- a/scikit_mol/fingerprints.py +++ b/scikit_mol/fingerprints.py @@ -19,7 +19,6 @@ from sklearn.base import BaseEstimator, TransformerMixin from scikit_mol.core import check_transform_input -from scikit_mol._invalid import NumpyArrayWithInvalidInstances, rdkit_error_handling from abc import ABC, abstractmethod diff --git a/scikit_mol/wrapper.py b/scikit_mol/wrapper.py index 501c1d9..a1bf825 100644 --- a/scikit_mol/wrapper.py +++ b/scikit_mol/wrapper.py @@ -107,6 +107,10 @@ def filter_invalid_rows(fill_value=np.nan, warn_on_invalid=False): def decorator(func): @wraps(func) def wrapper(obj, X, *args, **kwargs): + if not getattr(obj, "handle_errors", True): + # If handle_errors is False, call the original function without filtering + return func(obj, X, *args, **kwargs) + valid_mask = np.isfinite(X).all(axis=1) if warn_on_invalid and not np.all(valid_mask): @@ -147,11 +151,11 @@ class NanGuardWrapper(BaseEstimator, TransformerMixin): def __init__( self, estimator: BaseEstimator, - replace_invalid: bool = True, + handle_errors: bool = True, replace_value=np.nan, ): super().__init__() - self.replace_invalid = replace_invalid + self.handle_errors = handle_errors self.replace_value = replace_value self.estimator = estimator diff --git a/tests/fixtures.py b/tests/fixtures.py index 11636d3..57bcf60 100644 --- a/tests/fixtures.py +++ b/tests/fixtures.py @@ -87,6 +87,7 @@ def chiral_smiles_list(): # Need to be a certain size, so the fingerprints reac @pytest.fixture def invalid_smiles_list(smiles_list): + smiles_list = smiles_list.copy() smiles_list.append("Invalid") return smiles_list diff --git a/tests/invalid_handling.py b/tests/test_invalid_handling.py similarity index 100% rename from tests/invalid_handling.py rename to tests/test_invalid_handling.py From bd3b262069388008e631fd1b63a39bb9b3543e97 Mon Sep 17 00:00:00 2001 From: Esben Jannik Bjerrum Date: Fri, 27 Sep 2024 15:12:20 +0200 Subject: [PATCH 21/41] formatting changes --- tests/test_invalid_handling.py | 42 ++++++++++++++++++++-------------- 1 file changed, 25 insertions(+), 17 deletions(-) diff --git a/tests/test_invalid_handling.py b/tests/test_invalid_handling.py index 884154b..fda9442 100644 --- a/tests/test_invalid_handling.py +++ b/tests/test_invalid_handling.py @@ -5,39 +5,47 @@ from fixtures import smiles_list, invalid_smiles_list from scikit_mol.conversions import SmilesToMolTransformer -from scikit_mol.fingerprints import MorganFingerprintTransformer -from scikit_mol.wrapper import WrappedTransformer -from scikit_mol._invalid import NumpyArrayWithInvalidInstances -from test_invalid_helpers.invalid_transformer import TestInvalidTransformer +from scikit_mol.fingerprints import ( + MorganFingerprintTransformer, + MACCSKeysFingerprintTransformer, +) +from scikit_mol.wrapper import NanGuardWrapper # WrappedTransformer + +# from scikit_mol._invalid import NumpyArrayWithInvalidInstances +# from test_invalid_helpers.invalid_transformer import TestInvalidTransformer @pytest.fixture def smilestofp_pipeline(): pipeline = Pipeline( [ - ("smiles_to_mol", SmilesToMolTransformer()), - ("remove_sulfur", TestInvalidTransformer()), - ("mol_2_fp", MorganFingerprintTransformer()), - ("PCA", WrappedTransformer(PCA(2), replace_invalid=True)), + ("smiles_to_mol", SmilesToMolTransformer(handle_errors=True)), + ("mol_2_fp", MACCSKeysFingerprintTransformer()), + ("PCA", NanGuardWrapper(PCA(2), handle_errors=True)), ] ) return pipeline def test_descriptor_transformer(smiles_list, invalid_smiles_list, smilestofp_pipeline): - smilestofp_pipeline.set_params() + # smilestofp_pipeline.set_params() mol_pca = smilestofp_pipeline.fit_transform(smiles_list) error_mol_pca = smilestofp_pipeline.fit_transform(invalid_smiles_list) - if mol_pca.shape != (len(smiles_list), 2): - raise ValueError("The PCA does not return the proper dimensions.") - if isinstance(error_mol_pca, NumpyArrayWithInvalidInstances): - raise TypeError("The Errors were not properly remove from the output array.") + print(mol_pca.shape) + assert mol_pca.shape == ( + len(smiles_list), + 2, + ), "The PCA does not return the proper dimensions." - expected_nans = np.array([[0, 0, 1, 1], [0, 1, 0, 1]]) - if not np.all(np.equal(expected_nans, np.where(np.isnan(error_mol_pca)))): + expected_nans = np.array([[0, 0, 0, 0, 0, 1], [0, 0, 0, 0, 0, 1]]).T + if not np.all(np.equal(expected_nans, np.isnan(error_mol_pca))): raise ValueError("Errors were replaced on the wrong positions.") non_nan_rows = ~np.any(np.isnan(error_mol_pca), axis=1) - if not np.all(np.isclose(mol_pca, error_mol_pca[non_nan_rows, :])): - raise ValueError("Removing errors introduces changes in the PCA output.") + assert np.all( + np.isclose(mol_pca, error_mol_pca[non_nan_rows, :]) + ), "Removing errors introduces changes in the PCA output." + + # TODO, test with and without error handling on + # TODO, test with other transformers From bb5c50610fce9ed8c4a919cf93e2d86960c7ef40 Mon Sep 17 00:00:00 2001 From: Esben Jannik Bjerrum Date: Fri, 27 Sep 2024 15:12:43 +0200 Subject: [PATCH 22/41] Cleanup --- tests/test_invalid_helpers/__init__.py | 1 - .../invalid_transformer.py | 44 ------------------- 2 files changed, 45 deletions(-) delete mode 100644 tests/test_invalid_helpers/__init__.py delete mode 100644 tests/test_invalid_helpers/invalid_transformer.py diff --git a/tests/test_invalid_helpers/__init__.py b/tests/test_invalid_helpers/__init__.py deleted file mode 100644 index d0c583a..0000000 --- a/tests/test_invalid_helpers/__init__.py +++ /dev/null @@ -1 +0,0 @@ -"""Initialize module for helper classes and functions used to test the handling of invalid inputs.""" diff --git a/tests/test_invalid_helpers/invalid_transformer.py b/tests/test_invalid_helpers/invalid_transformer.py deleted file mode 100644 index fb0add7..0000000 --- a/tests/test_invalid_helpers/invalid_transformer.py +++ /dev/null @@ -1,44 +0,0 @@ -from typing import Optional, Sequence -from sklearn.base import BaseEstimator, TransformerMixin -from rdkit import Chem - -from scikit_mol._invalid import ( - InvalidMol, - rdkit_error_handling, -) - - -class TestInvalidTransformer(BaseEstimator, TransformerMixin): - """This class is ment for tesing purposes only. - - All molecules with element number are returned as invalid instance. - - Attributes - --------- - atomic_number_set: set[int] - Atomic numbers which upon occurrence in the molecule make it invalid. - """ - - atomic_number_set: set[int] - - def __init__(self, atomic_number_set: Sequence[int] | None = None) -> None: - if atomic_number_set is None: - atomic_number_set = {16} - self.atomic_number_set = set(atomic_number_set) - - def _transform_mol(self, mol: Chem.Mol) -> Chem.Mol | InvalidMol: - unique_elements = {atom.GetAtomicNum() for atom in mol.GetAtoms()} - forbidden_elements = self.atomic_number_set & unique_elements - if forbidden_elements: - return InvalidMol(str(self), f"Molecule contains {forbidden_elements}") - return mol - - @rdkit_error_handling - def transform(self, X: list[Chem.Mol]) -> list[Chem.Mol | InvalidMol]: - return [self._transform_mol(mol) for mol in X] - - def fit(self, X, y, fit_params): - pass - - def fit_transform(self, X, y=None, **fit_params): - return self.transform(X) From 91fff950eefd8164c81794eaa7f327e7b80ac1d9 Mon Sep 17 00:00:00 2001 From: Esben Jannik Bjerrum Date: Fri, 27 Sep 2024 16:06:30 +0200 Subject: [PATCH 23/41] Added error handling to fingerprint classes. Also added a utility to switch the handle_errors settings recursively on a sklearn pipeline or similar. --- scikit_mol/fingerprints.py | 114 +++++++++++++++++++-------------- scikit_mol/utilities.py | 80 +++++++++++++++++++++-- scikit_mol/wrapper.py | 2 +- tests/test_invalid_handling.py | 2 +- 4 files changed, 142 insertions(+), 56 deletions(-) diff --git a/scikit_mol/fingerprints.py b/scikit_mol/fingerprints.py index 146ede2..38569fb 100644 --- a/scikit_mol/fingerprints.py +++ b/scikit_mol/fingerprints.py @@ -30,9 +30,15 @@ # %% class FpsTransformer(ABC, BaseEstimator, TransformerMixin): - def __init__(self, parallel: Union[bool, int] = False, start_method: str = None): + def __init__( + self, + parallel: Union[bool, int] = False, + start_method: str = None, + handle_errors: bool = False, + ): self.parallel = parallel - self.start_method = start_method # TODO implement handling of start_method + self.start_method = start_method + self.handle_errors = handle_errors # The dtype of the fingerprint array computed by the transformer # If needed this property can be overwritten in the child class. @@ -73,7 +79,7 @@ def get_feature_names_out(self, input_features=None): @abstractmethod def _mol2fp(self, mol): - """Generate descriptor from mol + """Generate fingerprint from mol MUST BE OVERWRITTEN """ @@ -88,9 +94,16 @@ def _fp2array(self, fp): return arr def _transform_mol(self, mol): - fp = self._mol2fp(mol) - arr = self._fp2array(fp) - return arr + if not mol and self.handle_errors: + return self._fp2array(False) + try: + fp = self._mol2fp(mol) + return self._fp2array(fp) + except Exception as e: + if self.handle_errors: + return self._fp2array(False) + else: + raise e def fit(self, X, y=None): """Included for scikit-learn compatibility @@ -162,11 +175,11 @@ def transform(self, X, y=None): class MACCSKeysFingerprintTransformer(FpsTransformer): _DTYPE_FINGERPRINT = float - def __init__(self, parallel: Union[bool, int] = False): + def __init__(self, parallel: Union[bool, int] = False, handle_errors: bool = False): """MACCS keys fingerprinter calculates the 167 fixed MACCS keys """ - super().__init__(parallel=parallel) + super().__init__(parallel=parallel, handle_errors=handle_errors) self.nBits = 167 @property @@ -182,10 +195,7 @@ def nBits(self, nBits): self._nBits = nBits def _mol2fp(self, mol): - if mol: - return rdMolDescriptors.GetMACCSKeysFingerprint(mol) - else: - return False + return rdMolDescriptors.GetMACCSKeysFingerprint(mol) class RDKitFingerprintTransformer(FpsTransformer): @@ -202,6 +212,7 @@ def __init__( numBitsPerFeature: int = 2, atomInvariantsGenerator=None, parallel: Union[bool, int] = False, + handle_errors: bool = False, ): """Calculates the RDKit fingerprints @@ -228,7 +239,7 @@ def __init__( atomInvariantsGenerator : _type_, optional atom invariants to be used during fingerprint generation, by default None """ - super().__init__(parallel=parallel) + super().__init__(parallel=parallel, handle_errors=handle_errors) self.minPath = minPath self.maxPath = maxPath self.useHs = useHs @@ -265,9 +276,7 @@ def _mol2fp(self, mol): return generator.GetFingerprint(mol) -class AtomPairFingerprintTransformer( - FpsTransformer -): # FIXME, some of the init arguments seems to be molecule specific, and should probably not be setable? +class AtomPairFingerprintTransformer(FpsTransformer): def __init__( self, minLength: int = 1, @@ -282,8 +291,9 @@ def __init__( nBits=2048, useCounts: bool = False, parallel: Union[bool, int] = False, + handle_errors: bool = False, ): - super().__init__(parallel=parallel) + super().__init__(parallel=parallel, handle_errors=handle_errors) self.minLength = minLength self.maxLength = maxLength self.fromAtoms = fromAtoms @@ -297,33 +307,36 @@ def __init__( self.useCounts = useCounts def _mol2fp(self, mol): - if self.useCounts: - return rdMolDescriptors.GetHashedAtomPairFingerprint( - mol, - nBits=int(self.nBits), - minLength=int(self.minLength), - maxLength=int(self.maxLength), - fromAtoms=self.fromAtoms, - ignoreAtoms=self.ignoreAtoms, - atomInvariants=self.atomInvariants, - includeChirality=bool(self.includeChirality), - use2D=bool(self.use2D), - confId=int(self.confId), - ) + if mol: + if self.useCounts: + return rdMolDescriptors.GetHashedAtomPairFingerprint( + mol, + nBits=int(self.nBits), + minLength=int(self.minLength), + maxLength=int(self.maxLength), + fromAtoms=self.fromAtoms, + ignoreAtoms=self.ignoreAtoms, + atomInvariants=self.atomInvariants, + includeChirality=bool(self.includeChirality), + use2D=bool(self.use2D), + confId=int(self.confId), + ) + else: + return rdMolDescriptors.GetHashedAtomPairFingerprintAsBitVect( + mol, + nBits=int(self.nBits), + minLength=int(self.minLength), + maxLength=int(self.maxLength), + fromAtoms=self.fromAtoms, + ignoreAtoms=self.ignoreAtoms, + atomInvariants=self.atomInvariants, + nBitsPerEntry=int(self.nBitsPerEntry), + includeChirality=bool(self.includeChirality), + use2D=bool(self.use2D), + confId=int(self.confId), + ) else: - return rdMolDescriptors.GetHashedAtomPairFingerprintAsBitVect( - mol, - nBits=int(self.nBits), - minLength=int(self.minLength), - maxLength=int(self.maxLength), - fromAtoms=self.fromAtoms, - ignoreAtoms=self.ignoreAtoms, - atomInvariants=self.atomInvariants, - nBitsPerEntry=int(self.nBitsPerEntry), - includeChirality=bool(self.includeChirality), - use2D=bool(self.use2D), - confId=int(self.confId), - ) + return False class TopologicalTorsionFingerprintTransformer(FpsTransformer): @@ -338,8 +351,9 @@ def __init__( nBits=2048, useCounts: bool = False, parallel: Union[bool, int] = False, + handle_errors: bool = False, ): - super().__init__(parallel=parallel) + super().__init__(parallel=parallel, handle_errors=handle_errors) self.targetSize = targetSize self.fromAtoms = fromAtoms self.ignoreAtoms = ignoreAtoms @@ -385,6 +399,7 @@ def __init__( n_permutations: int = 2048, seed: int = 42, parallel: Union[bool, int] = False, + handle_errors: bool = False, ): """Transforms the RDKit mol into the MinHash fingerprint (MHFP) @@ -398,7 +413,7 @@ def __init__( this is effectively the length of the FP seed (int, optional): The value used to seed numpy.random. Defaults to 0. """ - super().__init__(parallel=parallel) + super().__init__(parallel=parallel, handle_errors=handle_errors) self.radius = radius self.rings = rings self.isomeric = isomeric @@ -478,6 +493,7 @@ def __init__( n_permutations: int = 0, seed: int = 0, parallel: Union[bool, int] = False, + handle_errors: bool = False, ): """Transforms the RDKit mol into the SMILES extended connectivity fingerprint (SECFP) @@ -491,7 +507,7 @@ def __init__( n_permutations (int, optional): The number of permutations used for hashing. Defaults to 0. seed (int, optional): The value used to seed numpy.random. Defaults to 0. """ - super().__init__(parallel=parallel) + super().__init__(parallel=parallel, handle_errors=handle_errors) self.radius = radius self.rings = rings self.isomeric = isomeric @@ -569,6 +585,7 @@ def __init__( useFeatures=False, useCounts=False, parallel: Union[bool, int] = False, + handle_errors: bool = False, ): """Transform RDKit mols into Count or bit-based hashed MorganFingerprints @@ -587,7 +604,7 @@ def __init__( useCounts : bool, optional If toggled will create the count and not bit-based fingerprint, by default False """ - super().__init__(parallel=parallel) + super().__init__(parallel=parallel, handle_errors=handle_errors) self.nBits = nBits self.radius = radius self.useChirality = useChirality @@ -626,6 +643,7 @@ def __init__( bitFlags: int = 15761407, useCounts: bool = False, parallel: Union[bool, int] = False, + handle_errors: bool = False, ): """Transform RDKit mols into Count or bit-based Avalon Fingerprints @@ -642,7 +660,7 @@ def __init__( useCounts : bool, optional If toggled will create the count and not bit-based fingerprint, by default False """ - super().__init__(parallel=parallel) + super().__init__(parallel=parallel, handle_errors=handle_errors) self.nBits = nBits self.isQuery = isQuery self.resetVect = resetVect diff --git a/scikit_mol/utilities.py b/scikit_mol/utilities.py index 866a9aa..481116a 100644 --- a/scikit_mol/utilities.py +++ b/scikit_mol/utilities.py @@ -1,14 +1,19 @@ -#For a non-scikit-learn check smiles sanitizer class +# For a non-scikit-learn check smiles sanitizer class import pandas as pd from rdkit import Chem +from sklearn.base import BaseEstimator +from sklearn.pipeline import Pipeline, FeatureUnion +from sklearn.compose import ColumnTransformer +import warnings + class CheckSmilesSanitazion: def __init__(self, return_mol=False): self.return_mol = return_mol self.errors = pd.DataFrame() - + def sanitize(self, X_smiles_list, y=None): if y: y_out = [] @@ -29,9 +34,11 @@ def sanitize(self, X_smiles_list, y=None): y_errors.append(y_value) if X_errors: - print(f'Error in parsing {len(X_errors)} SMILES. Unparsable SMILES can be found in self.errors') + print( + f"Error in parsing {len(X_errors)} SMILES. Unparsable SMILES can be found in self.errors" + ) - self.errors = pd.DataFrame({'SMILES':X_errors, 'y':y_errors}) + self.errors = pd.DataFrame({"SMILES": X_errors, "y": y_errors}) return X_out, y_out, X_errors, y_errors @@ -50,8 +57,69 @@ def sanitize(self, X_smiles_list, y=None): X_errors.append(smiles) if X_errors: - print(f'Error in parsing {len(X_errors)} SMILES. Unparsable SMILES can be found in self.errors') + print( + f"Error in parsing {len(X_errors)} SMILES. Unparsable SMILES can be found in self.errors" + ) - self.errors = pd.DataFrame({'SMILES':X_errors}) + self.errors = pd.DataFrame({"SMILES": X_errors}) return X_out, X_errors + + +def set_handle_errors(estimator, value): + """ + Recursively set the handle_errors parameter for all compatible estimators. + + :param estimator: A scikit-learn estimator, pipeline, or custom wrapper + :param value: Boolean value to set for handle_errors + """ + + def _set_handle_errors_recursive(est, val): + if hasattr(est, "handle_errors"): + est.handle_errors = val + + # Handle Pipeline + if isinstance(est, Pipeline): + for _, step in est.steps: + _set_handle_errors_recursive(step, val) + + # Handle FeatureUnion + elif isinstance(est, FeatureUnion): + for _, transformer in est.transformer_list: + _set_handle_errors_recursive(transformer, val) + + # Handle ColumnTransformer + elif isinstance(est, ColumnTransformer): + for _, transformer, _ in est.transformers: + _set_handle_errors_recursive(transformer, val) + + # Handle NanGuardWrapper + elif hasattr(est, "estimator") and isinstance(est.estimator, BaseEstimator): + _set_handle_errors_recursive(est.estimator, val) + + # Handle other estimators with get_params + elif isinstance(est, BaseEstimator): + params = est.get_params(deep=False) + for param_name, param_value in params.items(): + if isinstance(param_value, BaseEstimator): + _set_handle_errors_recursive(param_value, val) + + # Apply the recursive function + _set_handle_errors_recursive(estimator, value) + + # Final check + params = estimator.get_params(deep=True) + mismatched_params = [ + key.rstrip("__handle_errors") + for key, val in params.items() + if key.endswith("__handle_errors") and val != value + ] + + if mismatched_params: + warnings.warn( + f"The following components have 'handle_errors' set to a different value than requested: {mismatched_params}. " + "This could be due to nested estimators that were not properly handled.", + UserWarning, + ) + + return estimator diff --git a/scikit_mol/wrapper.py b/scikit_mol/wrapper.py index a1bf825..13bda6c 100644 --- a/scikit_mol/wrapper.py +++ b/scikit_mol/wrapper.py @@ -151,7 +151,7 @@ class NanGuardWrapper(BaseEstimator, TransformerMixin): def __init__( self, estimator: BaseEstimator, - handle_errors: bool = True, + handle_errors: bool = False, replace_value=np.nan, ): super().__init__() diff --git a/tests/test_invalid_handling.py b/tests/test_invalid_handling.py index fda9442..dfa327b 100644 --- a/tests/test_invalid_handling.py +++ b/tests/test_invalid_handling.py @@ -20,7 +20,7 @@ def smilestofp_pipeline(): pipeline = Pipeline( [ ("smiles_to_mol", SmilesToMolTransformer(handle_errors=True)), - ("mol_2_fp", MACCSKeysFingerprintTransformer()), + ("mol_2_fp", MACCSKeysFingerprintTransformer(handle_errors=True)), ("PCA", NanGuardWrapper(PCA(2), handle_errors=True)), ] ) From 50d9004b54c8a2127729519816f9fed2f9e5bde0 Mon Sep 17 00:00:00 2001 From: Esben Jannik Bjerrum Date: Fri, 27 Sep 2024 16:16:02 +0200 Subject: [PATCH 24/41] Updated standardizer and transformer for handling the errors. Still need to test it extensively. pytests test proceed, but only because error handling is default switched off. Also need to update docstrings and cleanup the files --- scikit_mol/descriptors.py | 58 ++++++++++++++++++----------- scikit_mol/standardizer.py | 75 +++++++++++++++++++++++--------------- scikit_mol/wrapper.py | 2 +- 3 files changed, 84 insertions(+), 51 deletions(-) diff --git a/scikit_mol/descriptors.py b/scikit_mol/descriptors.py index a516fc5..c461051 100644 --- a/scikit_mol/descriptors.py +++ b/scikit_mol/descriptors.py @@ -12,10 +12,9 @@ from scikit_mol.core import check_transform_input - class MolecularDescriptorTransformer(BaseEstimator, TransformerMixin): """Descriptor calculation transformer - + Parameters ---------- desc_list : (List of descriptor names) @@ -23,7 +22,7 @@ class MolecularDescriptorTransformer(BaseEstimator, TransformerMixin): parallel : boolean, int if True, multiprocessing will be used. If set to an int > 1, that specified number of processes will be used, otherwise it's autodetected. - start_method : str + start_method : str The method to start child processes when parallel=True. can be 'fork', 'spawn' or 'forkserver'. If None, the OS and Pythons default will be used. @@ -34,14 +33,18 @@ class MolecularDescriptorTransformer(BaseEstimator, TransformerMixin): """ + def __init__( - self, desc_list: Optional[str] = None, + self, + desc_list: Optional[str] = None, parallel: Union[bool, int] = False, - start_method: str = None#"fork" - ): + start_method: str = None, # "fork", + handle_errors: bool = False, + ): self.desc_list = desc_list self.parallel = parallel self.start_method = start_method + self.handle_errors = handle_errors def _get_desc_calculator(self) -> MolecularDescriptorCalculator: if self.desc_list: @@ -50,9 +53,7 @@ def _get_desc_calculator(self) -> MolecularDescriptorCalculator: for desc_name in self.desc_list if desc_name not in self.available_descriptors ] - assert ( - not unknown_descriptors - ), f"Unknown descriptor names {unknown_descriptors} specified, please check available_descriptors property\nPlease check availble list {self.available_descriptors}" + assert not unknown_descriptors, f"Unknown descriptor names {unknown_descriptors} specified, please check available_descriptors property\nPlease check availble list {self.available_descriptors}" else: self.desc_list = self.available_descriptors return MolecularDescriptorCalculator(self.desc_list) @@ -89,11 +90,21 @@ def start_method(self, start_method): """Allowed methods are spawn, fork and forkserver on MacOS and Linux, only spawn is possible on Windows. None will choose the default for the OS and version of Python.""" allowed_start_methods = ["spawn", "fork", "forkserver", None] - assert start_method in allowed_start_methods, f"start_method not in allowed methods {allowed_start_methods}" + assert ( + start_method in allowed_start_methods + ), f"start_method not in allowed methods {allowed_start_methods}" self._start_method = start_method def _transform_mol(self, mol: Mol) -> List[Any]: - return list(self.calculators.CalcDescriptors(mol)) + if not mol and self.handle_errors: + return [np.nan] * len(self.desc_list) + try: + return list(self.calculators.CalcDescriptors(mol)) + except Exception as e: + if self.handle_errors: + return [np.nan] * len(self.desc_list) + else: + raise e def fit(self, x, y=None): """Included for scikit-learn compatibility, does nothing""" @@ -119,19 +130,25 @@ def transform(self, x: List[Mol], y=None) -> np.ndarray: ------- np.array Descriptors, shape (samples, length of .selected_descriptors ) - + """ if not self.parallel: return self._transform(x) elif self.parallel: - n_processes = self.parallel if self.parallel > 1 else None # Pool(processes=None) autodetects - n_chunks = n_processes if n_processes is not None else multiprocessing.cpu_count() #TODO, tune the number of chunks per child process - + n_processes = ( + self.parallel if self.parallel > 1 else None + ) # Pool(processes=None) autodetects + n_chunks = ( + n_processes if n_processes is not None else multiprocessing.cpu_count() + ) # TODO, tune the number of chunks per child process + with get_context(self.start_method).Pool(processes=n_processes) as pool: params = self.get_params() - x_chunks = np.array_split(x, n_chunks) - #x_chunks = [x.reshape(-1, 1) for x in x_chunks] - arrays = pool.map(parallel_helper, [(params, x) for x in x_chunks]) #is the helper function a safer way of handling the picklind and child process communication + x_chunks = np.array_split(x, n_chunks) + # x_chunks = [x.reshape(-1, 1) for x in x_chunks] + arrays = pool.map( + parallel_helper, [(params, x) for x in x_chunks] + ) # is the helper function a safer way of handling the picklind and child process communication arr = np.concatenate(arrays) return arr @@ -139,12 +156,11 @@ def transform(self, x: List[Mol], y=None) -> np.ndarray: # May be safer to instantiate the transformer object in the child process, and only transfer the parameters # There were issues with freezing when using RDKit 2022.3 def parallel_helper(args): - """Will get a tuple with Desc2DTransformer parameters and mols to transform. + """Will get a tuple with Desc2DTransformer parameters and mols to transform. Will then instantiate the transformer and transform the molecules""" from scikit_mol.descriptors import MolecularDescriptorTransformer - + params, mols = args transformer = MolecularDescriptorTransformer(**params) y = transformer._transform(mols) return y - \ No newline at end of file diff --git a/scikit_mol/standardizer.py b/scikit_mol/standardizer.py index 5277f20..105616f 100644 --- a/scikit_mol/standardizer.py +++ b/scikit_mol/standardizer.py @@ -12,36 +12,39 @@ class Standardizer(BaseEstimator, TransformerMixin): - """ Input a list of rdkit mols, output the same list but standardised - """ + """Input a list of rdkit mols, output the same list but standardised""" + def __init__(self, neutralize=True, parallel=False): self.neutralize = neutralize self.parallel = parallel - self.start_method = None #TODO implement handling of start_method + self.start_method = None # TODO implement handling of start_method def fit(self, X, y=None): - return self - + return self + def _transform(self, X): - block = BlockLogs() # Block all RDkit logging + block = BlockLogs() # Block all RDkit logging arr = [] for mol in X: - # Normalizing functional groups - # https://molvs.readthedocs.io/en/latest/guide/standardize.html - clean_mol = rdMolStandardize.Cleanup(mol) - # Get parents fragments - parent_clean_mol = rdMolStandardize.FragmentParent(clean_mol) - # Neutralise - if self.neutralize: - uncharger = rdMolStandardize.Uncharger() - uncharged_parent_clean_mol = uncharger.uncharge(parent_clean_mol) + if mol: # Falsy mols can't be processed, (e.g. if InvalidMol objects) + # Normalizing functional groups + # https://molvs.readthedocs.io/en/latest/guide/standardize.html + clean_mol = rdMolStandardize.Cleanup(mol) + # Get parents fragments + parent_clean_mol = rdMolStandardize.FragmentParent(clean_mol) + # Neutralise + if self.neutralize: + uncharger = rdMolStandardize.Uncharger() + uncharged_parent_clean_mol = uncharger.uncharge(parent_clean_mol) + else: + uncharged_parent_clean_mol = parent_clean_mol + # Add to final list + arr.append(uncharged_parent_clean_mol) else: - uncharged_parent_clean_mol = parent_clean_mol - # Add to final list - arr.append(uncharged_parent_clean_mol) - - del block # Release logging block to previous state - return np.array(arr).reshape(-1,1) + arr.append(mol) + + del block # Release logging block to previous state + return np.array(arr).reshape(-1, 1) @feature_names_default_mol def get_feature_names_out(self, input_features=None): @@ -53,25 +56,39 @@ def transform(self, X, y=None): return self._transform(X) elif self.parallel: - n_processes = self.parallel if self.parallel > 1 else None # Pool(processes=None) autodetects - n_chunks = n_processes*2 if n_processes is not None else multiprocessing.cpu_count()*2 #TODO, tune the number of chunks per child process - - with multiprocessing.get_context(self.start_method).Pool(processes=n_processes) as pool: + n_processes = ( + self.parallel if self.parallel > 1 else None + ) # Pool(processes=None) autodetects + n_chunks = ( + n_processes * 2 + if n_processes is not None + else multiprocessing.cpu_count() * 2 + ) # TODO, tune the number of chunks per child process + + with multiprocessing.get_context(self.start_method).Pool( + processes=n_processes + ) as pool: x_chunks = np.array_split(X, n_chunks) - #TODO check what is fastest, pickle or recreate and do this only for classes that need this - #arrays = pool.map(self._transform, x_chunks) + # TODO check what is fastest, pickle or recreate and do this only for classes that need this + # arrays = pool.map(self._transform, x_chunks) parameters = self.get_params() - arrays = pool.map(parallel_helper, [(self.__class__.__name__, parameters, x_chunk) for x_chunk in x_chunks]) + arrays = pool.map( + parallel_helper, + [ + (self.__class__.__name__, parameters, x_chunk) + for x_chunk in x_chunks + ], + ) arr = np.concatenate(arrays) return arr - def parallel_helper(args): """Parallel_helper takes a tuple with classname, the objects parameters and the mols to process. Then instantiates the class with the parameters and processes the mol. Intention is to be able to do this in chilcprocesses as some classes can't be pickled""" classname, parameters, X_mols = args from scikit_mol import standardizer + transformer = getattr(standardizer, classname)(**parameters) return transformer._transform(X_mols) diff --git a/scikit_mol/wrapper.py b/scikit_mol/wrapper.py index 13bda6c..5fbdca2 100644 --- a/scikit_mol/wrapper.py +++ b/scikit_mol/wrapper.py @@ -111,7 +111,7 @@ def wrapper(obj, X, *args, **kwargs): # If handle_errors is False, call the original function without filtering return func(obj, X, *args, **kwargs) - valid_mask = np.isfinite(X).all(axis=1) + valid_mask = np.isfinite(X).all(axis=1) # Find all rows with nan, inf, etc. if warn_on_invalid and not np.all(valid_mask): warnings.warn( From 555afafbab71812dfb3678aea61b4a5a0a6c37ce Mon Sep 17 00:00:00 2001 From: Esben Jannik Bjerrum Date: Fri, 27 Sep 2024 16:23:47 +0200 Subject: [PATCH 25/41] Cleaning up.- We are getting closer --- scikit_mol/_invalid.py | 149 ----------------------------------------- scikit_mol/wrapper.py | 94 +------------------------- 2 files changed, 1 insertion(+), 242 deletions(-) delete mode 100644 scikit_mol/_invalid.py diff --git a/scikit_mol/_invalid.py b/scikit_mol/_invalid.py deleted file mode 100644 index 7130216..0000000 --- a/scikit_mol/_invalid.py +++ /dev/null @@ -1,149 +0,0 @@ -from abc import ABC -from typing import Any, Callable, NamedTuple, Sequence, TypeVar - -import numpy as np -import numpy.typing as npt - -_T = TypeVar("_T") -_U = TypeVar("_U") - - -class InvalidInstance(NamedTuple): - """ - The InvalidInstance represents objects which raised an error during a pipeline step. - """ - - pipeline_step: str - error: str - - -class NumpyArrayWithInvalidInstances: - """ - The NumpyArrayWithInvalidInstances is - """ - - is_valid_array: npt.NDArray[np.bool_] - invalid_list: list[InvalidInstance] - value_array: npt.NDArray[Any] - - def __init__(self, array_list: list[npt.NDArray[Any] | InvalidInstance]): - self.is_valid_array = get_is_valid_array(array_list) - valid_vector_list = filter_by_list(array_list, self.is_valid_array) - self.value_array = np.vstack(valid_vector_list) - self.invalid_list = filter_by_list(array_list, ~self.is_valid_array) - - def __len__(self): - return self.is_valid_array.shape[0] - - def __getitem__(self, item: int) -> npt.NDArray[Any] | InvalidInstance: - n_invalids_prior = sum(~self.is_valid_array[: item - 1]) - if self.is_valid_array[item]: - return self.value_array[item - n_invalids_prior] - return self.invalid_list[n_invalids_prior + 1] - - def __setitem__(self, key: int, value: npt.NDArray[Any] | InvalidInstance) -> None: - n_invalids_prior = sum(~self.is_valid_array[: key - 1]) - if isinstance(value, InvalidInstance): - if self.is_valid_array[key]: - self.value_array = np.delete(self.value_array, key - n_invalids_prior) - self.is_valid_array[key] = False - self.invalid_list.insert(n_invalids_prior + 1, value) - else: - self.invalid_list[n_invalids_prior + 1] = value - else: - if self.is_valid_array[key]: - self.value_array[key - n_invalids_prior] = value - else: - self.value_array = np.insert( - self.value_array, key - n_invalids_prior, value - ) - del self.invalid_list[n_invalids_prior + 1] - self.is_valid_array[key] = True - - def array_filled_with(self, fill_value) -> npt.NDArray[Any]: - out = np.full((len(self.is_valid_array), self.value_array.shape[1]), fill_value) - out[self.is_valid_array] = self.value_array - return out - - -def batch_update_sequence( - old_list: list[npt.NDArray[Any] | InvalidInstance] | NumpyArrayWithInvalidInstances, - new_values: list[Any], - value_indices: npt.NDArray[np.int_], -): - old_list = list(old_list) # Make shallow copy of list to avoid inplace changes. - for new_value, idx in zip(new_values, value_indices, strict=True): - old_list[idx] = new_value - return old_list - - -def filter_by_list(item_list, is_valid_array: npt.NDArray[np.bool_]): - if isinstance(item_list, np.ndarray): - return item_list[is_valid_array] - - item_list_new = [] - for item, is_valid in zip(item_list, is_valid_array): - if is_valid: - item_list_new.append(item) - return item_list_new - - -# Callable[[Sequence[Any], Sequence[Any], dict[str, Any]], Sequence[Any]] -# ) -> Callable[[Sequence[Any], Sequence[Any], dict[str, Any]], npt.NDArray[Any]] -def rdkit_error_handling(func): - def wrapper(obj, *args, **kwargs): - x = args[0] - if isinstance(x, NumpyArrayWithInvalidInstances): - is_valid_array = x.is_valid_array - x_sub = x.value_array - else: - is_valid_array = get_is_valid_array(x) - x_sub = filter_by_list(x, is_valid_array) - if len(args) > 1: - y = args[1] - if y is not None: - y_sub = filter_by_list(y, is_valid_array) - else: - y_sub = None - x_new = func(obj, x_sub, y_sub, **kwargs) - else: - x_new = func(obj, x_sub, **kwargs) - - if x_new is None: # fit may not return anything - return None - new_pos = np.where(is_valid_array)[0] - if isinstance(x_new, np.ndarray) and isinstance( - x, NumpyArrayWithInvalidInstances - ): - if x_new.shape[0] != x.value_array.shape[0]: - raise AssertionError("Numer of rows is not as expected.") - x.value_array = x_new - return x - if isinstance(x, (list, NumpyArrayWithInvalidInstances)): - x_list = batch_update_sequence(x, x_new, new_pos) - else: - x_array = np.array(x) - x_array[is_valid_array] = x_new - x_list = list(x_array) - if isinstance(x_new, NumpyArrayWithInvalidInstances): - return NumpyArrayWithInvalidInstances(x_list) - return x_list - - return wrapper - - -def filter_rows(X: Sequence[_T], y: Sequence[_U]) -> tuple[Sequence[_T], Sequence[_U]]: - is_valid_array = get_is_valid_array(X) - x_new = filter_by_list(X, is_valid_array) - y_new = filter_by_list(y, is_valid_array) - return x_new, y_new - - -def get_is_valid_array(item_list: Sequence[Any]) -> npt.NDArray[np.bool_]: - is_valid_list = [] - for i, item in enumerate(item_list): - if not isinstance(item, InvalidInstance): - is_valid_list.append(True) - else: - is_valid_list.append(False) - return np.array(is_valid_list, dtype=bool) diff --git a/scikit_mol/wrapper.py b/scikit_mol/wrapper.py index 5fbdca2..79c4c08 100644 --- a/scikit_mol/wrapper.py +++ b/scikit_mol/wrapper.py @@ -1,6 +1,5 @@ """Wrapper for sklearn estimators and pipelines to handle errors.""" -from abc import ABC from typing import Any import numpy as np @@ -8,100 +7,9 @@ from functools import wraps import warnings from sklearn.base import BaseEstimator -from sklearn.pipeline import Pipeline from sklearn.utils.metaestimators import available_if from sklearn.base import TransformerMixin -from scikit_mol._invalid import ( - rdkit_error_handling, - # InvalidMol, - NumpyArrayWithInvalidInstances, -) - - -class AbstractWrapper(BaseEstimator, ABC): - """ - Abstract class for the wrapper of sklearn objects. - - Attributes - ---------- - model: BaseEstimator | Pipeline - The wrapped model or pipeline. - """ - - model: BaseEstimator | Pipeline - - def __init__(self, replace_invalid: bool, replace_value: Any = np.nan): - """Initialize the AbstractWrapper. - - Parameters - ---------- - replace_invalid: bool - Whether to replace or remove errors - replace_value: Any, default=np.nan - If replace_invalid==True, insert this value on the erroneous instance. - """ - self.replace_invalid = replace_invalid - self.replace_value = replace_value - - @rdkit_error_handling - def fit(self, X, y, **fit_params) -> Any: - return self.model.fit(X, y, **fit_params) - - def has_predict(self) -> bool: - return hasattr(self.model, "predict") - - def has_fit_predict(self) -> bool: - return hasattr(self.model, "fit_predict") - - -class WrappedTransformer(AbstractWrapper): - """Wrapper for sklearn transformer objects.""" - - def __init__( - self, model: BaseEstimator, replace_invalid: bool = False, replace_value=np.nan - ): - """Initialize the WrappedTransformer. - - Parameters - ---------- - model: BaseEstimator - Wrapped model to be protected against Errors. - replace_invalid: bool - Whether to replace or remove errors - replace_value: Any, default=np.nan - If replace_invalid==True, insert this value on the erroneous instance. - """ - super().__init__(replace_invalid=replace_invalid, replace_value=replace_value) - self.model = model - - def has_transform(self) -> bool: - return hasattr(self.model, "transform") - - def has_fit_transform(self) -> bool: - return hasattr(self.model, "fit_transform") - - @available_if(has_transform) - @rdkit_error_handling - def transform(self, X): - return self.model.transform(X) - - @rdkit_error_handling - def _fit_transform(self, X, y): - return self.model.fit_transform(X, y) - - @available_if(has_fit_transform) - def fit_transform(self, X, y=None): - out = self._fit_transform(X, y) - if not self.replace_invalid: - return out - - if isinstance(out, NumpyArrayWithInvalidInstances): - return out.array_filled_with(self.replace_value) - - if isinstance(out, list): - return [self.replace_value if isinstance(v, InvalidMol) else v for v in out] - def filter_invalid_rows(fill_value=np.nan, warn_on_invalid=False): def decorator(func): @@ -214,6 +122,6 @@ def fit_transform(self, X, y): return self.estimator.fit_transform(X, y) @available_if(has_score) - @filter_invalid_rows() + @filter_invalid_rows(warn_on_invalid=True) def score(self, X, y): return self.estimator.score(X, y) From efbd8b9d7643d9768a2d73fdb3f41c7164acc135 Mon Sep 17 00:00:00 2001 From: Esben Jannik Bjerrum Date: Fri, 27 Sep 2024 17:16:51 +0200 Subject: [PATCH 26/41] Fixed a bug in a test --- tests/test_sanitizer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_sanitizer.py b/tests/test_sanitizer.py index a69d597..f7193c2 100644 --- a/tests/test_sanitizer.py +++ b/tests/test_sanitizer.py @@ -32,7 +32,7 @@ def test_checksmilessanitation_x_and_y(smiles_list, invalid_smiles_list, sanitiz assert errors[0] == sanitizer.errors.SMILES[0] # Test that y is correctly split into y_error and the rest assert all([a == b for a, b in zip(y_sanitized, list(range(len(smiles_list) - 1)))]) - assert y_errors[0] == len(smiles_list) - 1 # Last smiles is invalid + assert y_errors[0] == len(invalid_smiles_list) - 1 # Last smiles is invalid def test_checksmilessanitation_np(smiles_list, invalid_smiles_list, sanitizer): From 407841129248614fd5df0c75fe1ba5728e5d9c31 Mon Sep 17 00:00:00 2001 From: Esben Jannik Bjerrum Date: Thu, 3 Oct 2024 19:09:30 +0200 Subject: [PATCH 27/41] updating smiles_to_mol test case --- tests/test_smilestomol.py | 70 ++++++++++++++++++++++++++++++--------- 1 file changed, 55 insertions(+), 15 deletions(-) diff --git a/tests/test_smilestomol.py b/tests/test_smilestomol.py index e01af52..c951db2 100644 --- a/tests/test_smilestomol.py +++ b/tests/test_smilestomol.py @@ -7,32 +7,52 @@ import sklearn from scikit_mol.conversions import SmilesToMolTransformer from scikit_mol.core import SKLEARN_VERSION_PANDAS_OUT, DEFAULT_MOL_COLUMN_NAME -from fixtures import smiles_list, invalid_smiles_list, smiles_container, skip_pandas_output_test +from fixtures import ( + smiles_list, + invalid_smiles_list, + smiles_container, + skip_pandas_output_test, +) @pytest.fixture def smilestomol_transformer(): return SmilesToMolTransformer() + def test_smilestomol(smiles_container, smilestomol_transformer): - result_mols = smilestomol_transformer.transform(smiles_container) - result_smiles = [Chem.MolToSmiles(mol) for mol in result_mols.flatten()] - if isinstance(smiles_container, pd.DataFrame): - expected_smiles = smiles_container.iloc[:, 0].tolist() - else: - expected_smiles = smiles_container - assert all([ a == b for a, b in zip(expected_smiles, result_smiles)]) + result_mols = smilestomol_transformer.transform(smiles_container) + result_smiles = [Chem.MolToSmiles(mol) for mol in result_mols.flatten()] + if isinstance(smiles_container, pd.DataFrame): + expected_smiles = smiles_container.iloc[:, 0].tolist() + else: + expected_smiles = smiles_container + assert all([a == b for a, b in zip(expected_smiles, result_smiles)]) + + +def test_smilestomol_transform(smilestomol_transformer, smiles_container): + result = smilestomol_transformer.transform(smiles_container) + assert len(result) == len(smiles_container) + assert all(isinstance(mol, Chem.Mol) for mol in result.flatten()) + + +def test_smilestomol_fit(smilestomol_transformer, smiles_container): + result = smilestomol_transformer.fit(smiles_container) + assert result == smilestomol_transformer + def test_smilestomol_clone(smilestomol_transformer): t2 = clone(smilestomol_transformer) - params = smilestomol_transformer.get_params() + params = smilestomol_transformer.get_params() params_2 = t2.get_params() - assert all([ params[key] == params_2[key] for key in params.keys()]) + assert all([params[key] == params_2[key] for key in params.keys()]) + def test_smilestomol_unsanitzable(invalid_smiles_list, smilestomol_transformer): with pytest.raises(ValueError): smilestomol_transformer.transform(invalid_smiles_list) + def test_descriptor_transformer_parallel(smiles_container, smilestomol_transformer): smilestomol_transformer.set_params(parallel=True) mol_list = smilestomol_transformer.transform(smiles_container) @@ -40,11 +60,31 @@ def test_descriptor_transformer_parallel(smiles_container, smilestomol_transform expected_smiles = smiles_container.iloc[:, 0].tolist() else: expected_smiles = smiles_container - assert all([ a == b for a, b in zip(expected_smiles, [Chem.MolToSmiles(mol) for mol in mol_list.flatten()])]) + assert all( + [ + a == b + for a, b in zip( + expected_smiles, [Chem.MolToSmiles(mol) for mol in mol_list.flatten()] + ) + ] + ) + + +def test_smilestomol_inverse_transform(smilestomol_transformer, smiles_container): + mols = smilestomol_transformer.transform(smiles_container) + result = smilestomol_transformer.inverse_transform(mols) + assert len(result) == len(smiles_container) + assert all(isinstance(smiles, str) for smiles in result.flatten()) + @skip_pandas_output_test def test_pandas_output(smiles_container, smilestomol_transformer, pandas_output): - mols = smilestomol_transformer.transform(smiles_container) - assert isinstance(mols, pd.DataFrame) - assert mols.shape[0] == len(smiles_container) - assert mols.columns.tolist() == [DEFAULT_MOL_COLUMN_NAME] \ No newline at end of file + mols = smilestomol_transformer.transform(smiles_container) + assert isinstance(mols, pd.DataFrame) + assert mols.shape[0] == len(smiles_container) + assert mols.columns.tolist() == [DEFAULT_MOL_COLUMN_NAME] + + +def test_smilestomol_get_feature_names_out(smilestomol_transformer): + feature_names = smilestomol_transformer.get_feature_names_out() + assert feature_names == [DEFAULT_MOL_COLUMN_NAME] From 63bfabe564da591014d9051d4fcfe7a227bb1047 Mon Sep 17 00:00:00 2001 From: Esben Jannik Bjerrum Date: Thu, 3 Oct 2024 19:29:33 +0200 Subject: [PATCH 28/41] Updated test of smilestomol to check for the handle_errors capabilities, plus other expansions. --- tests/test_smilestomol.py | 89 ++++++++++++++++++++++++++++++++++++--- 1 file changed, 83 insertions(+), 6 deletions(-) diff --git a/tests/test_smilestomol.py b/tests/test_smilestomol.py index c951db2..737080d 100644 --- a/tests/test_smilestomol.py +++ b/tests/test_smilestomol.py @@ -6,7 +6,11 @@ from rdkit import Chem import sklearn from scikit_mol.conversions import SmilesToMolTransformer -from scikit_mol.core import SKLEARN_VERSION_PANDAS_OUT, DEFAULT_MOL_COLUMN_NAME +from scikit_mol.core import ( + SKLEARN_VERSION_PANDAS_OUT, + DEFAULT_MOL_COLUMN_NAME, + InvalidMol, +) from fixtures import ( smiles_list, invalid_smiles_list, @@ -77,14 +81,87 @@ def test_smilestomol_inverse_transform(smilestomol_transformer, smiles_container assert all(isinstance(smiles, str) for smiles in result.flatten()) +def test_smilestomol_inverse_transform_with_invalid( + invalid_smiles_list, smilestomol_transformer +): + smilestomol_transformer.set_params(handle_errors=True) + + # Forward transform + mols = smilestomol_transformer.transform(invalid_smiles_list) + + # Inverse transform + result = smilestomol_transformer.inverse_transform(mols) + + assert len(result) == len(invalid_smiles_list) + + # Check that all but the last element are the same as the original SMILES + for original, res in zip(invalid_smiles_list[:-1], result[:-1].flatten()): + assert isinstance(res, str) + assert original == res + + # Check that the last element is an InvalidMol instance + assert isinstance(result[-1].item(), InvalidMol) + assert "Invalid SMILES" in result[-1].item().error + assert invalid_smiles_list[-1] in result[-1].item().error + + +def test_smilestomol_get_feature_names_out(smilestomol_transformer): + feature_names = smilestomol_transformer.get_feature_names_out() + assert feature_names == [DEFAULT_MOL_COLUMN_NAME] + + +def test_smilestomol_handle_errors(invalid_smiles_list, smilestomol_transformer): + smilestomol_transformer.set_params(handle_errors=True) + result = smilestomol_transformer.transform(invalid_smiles_list) + + assert len(result) == len(invalid_smiles_list) + assert isinstance(result, np.ndarray) + + # Check that all but the last element are valid RDKit Mol objects + for mol in result[:-1].flatten(): + assert isinstance(mol, Chem.Mol) + assert mol is not None + + # Check that the last element is an InvalidMol instance + last_mol = result[-1].item() + assert isinstance(last_mol, InvalidMol) + + # Check if the error message is correctly set for the invalid SMILES + assert "Invalid SMILES" in last_mol.error + assert invalid_smiles_list[-1] in last_mol.error + + +@pytest.mark.skipif( + not skip_pandas_output_test, + reason="Pandas output not supported in this sklearn version", +) +def test_smilestomol_handle_errors_pandas_output( + invalid_smiles_list, smilestomol_transformer, pandas_output +): + smilestomol_transformer.set_params(handle_errors=True) + result = smilestomol_transformer.transform(invalid_smiles_list) + + assert len(result) == len(invalid_smiles_list) + assert isinstance(result, pd.DataFrame) + assert result.columns == [DEFAULT_MOL_COLUMN_NAME] + + # Check that all but the last element are valid RDKit Mol objects + for mol in result[DEFAULT_MOL_COLUMN_NAME][:-1]: + assert isinstance(mol, Chem.Mol) + assert mol is not None + + # Check that the last element is an InvalidMol instance + last_mol = result[DEFAULT_MOL_COLUMN_NAME].iloc[-1] + assert isinstance(last_mol, InvalidMol) + + # Check if the error message is correctly set for the invalid SMILES + assert "Invalid SMILES" in last_mol.error + assert invalid_smiles_list[-1] in last_mol.error + + @skip_pandas_output_test def test_pandas_output(smiles_container, smilestomol_transformer, pandas_output): mols = smilestomol_transformer.transform(smiles_container) assert isinstance(mols, pd.DataFrame) assert mols.shape[0] == len(smiles_container) assert mols.columns.tolist() == [DEFAULT_MOL_COLUMN_NAME] - - -def test_smilestomol_get_feature_names_out(smilestomol_transformer): - feature_names = smilestomol_transformer.get_feature_names_out() - assert feature_names == [DEFAULT_MOL_COLUMN_NAME] From 85ec6ca63743ce204a582a312e95a6b81ea9f187 Mon Sep 17 00:00:00 2001 From: Esben Jannik Bjerrum Date: Thu, 3 Oct 2024 20:19:43 +0200 Subject: [PATCH 29/41] Developed a test of the error_handling for fingerprint transformers, still need to revisit the fixtures for invalid mols --- scikit_mol/fingerprints.py | 69 +++-- tests/fixtures.py | 15 +- tests/test_fptransformers.py | 505 ++++++++++++++++++++++++++--------- 3 files changed, 424 insertions(+), 165 deletions(-) diff --git a/scikit_mol/fingerprints.py b/scikit_mol/fingerprints.py index 38569fb..de211dc 100644 --- a/scikit_mol/fingerprints.py +++ b/scikit_mol/fingerprints.py @@ -42,7 +42,7 @@ def __init__( # The dtype of the fingerprint array computed by the transformer # If needed this property can be overwritten in the child class. - _DTYPE_FINGERPRINT = np.int8 + _DTYPE_FINGERPRINT = float # Float is necessary for the handle_errors to work def _get_column_prefix(self) -> str: matched = _PATTERN_FINGERPRINT_TRANSFORMER.match(type(self).__name__) @@ -173,8 +173,6 @@ def transform(self, X, y=None): class MACCSKeysFingerprintTransformer(FpsTransformer): - _DTYPE_FINGERPRINT = float - def __init__(self, parallel: Union[bool, int] = False, handle_errors: bool = False): """MACCS keys fingerprinter calculates the 167 fixed MACCS keys @@ -307,36 +305,33 @@ def __init__( self.useCounts = useCounts def _mol2fp(self, mol): - if mol: - if self.useCounts: - return rdMolDescriptors.GetHashedAtomPairFingerprint( - mol, - nBits=int(self.nBits), - minLength=int(self.minLength), - maxLength=int(self.maxLength), - fromAtoms=self.fromAtoms, - ignoreAtoms=self.ignoreAtoms, - atomInvariants=self.atomInvariants, - includeChirality=bool(self.includeChirality), - use2D=bool(self.use2D), - confId=int(self.confId), - ) - else: - return rdMolDescriptors.GetHashedAtomPairFingerprintAsBitVect( - mol, - nBits=int(self.nBits), - minLength=int(self.minLength), - maxLength=int(self.maxLength), - fromAtoms=self.fromAtoms, - ignoreAtoms=self.ignoreAtoms, - atomInvariants=self.atomInvariants, - nBitsPerEntry=int(self.nBitsPerEntry), - includeChirality=bool(self.includeChirality), - use2D=bool(self.use2D), - confId=int(self.confId), - ) + if self.useCounts: + return rdMolDescriptors.GetHashedAtomPairFingerprint( + mol, + nBits=int(self.nBits), + minLength=int(self.minLength), + maxLength=int(self.maxLength), + fromAtoms=self.fromAtoms, + ignoreAtoms=self.ignoreAtoms, + atomInvariants=self.atomInvariants, + includeChirality=bool(self.includeChirality), + use2D=bool(self.use2D), + confId=int(self.confId), + ) else: - return False + return rdMolDescriptors.GetHashedAtomPairFingerprintAsBitVect( + mol, + nBits=int(self.nBits), + minLength=int(self.minLength), + maxLength=int(self.maxLength), + fromAtoms=self.fromAtoms, + ignoreAtoms=self.ignoreAtoms, + atomInvariants=self.atomInvariants, + nBitsPerEntry=int(self.nBitsPerEntry), + includeChirality=bool(self.includeChirality), + use2D=bool(self.use2D), + confId=int(self.confId), + ) class TopologicalTorsionFingerprintTransformer(FpsTransformer): @@ -389,6 +384,11 @@ def _mol2fp(self, mol): class MHFingerprintTransformer(FpsTransformer): # https://jcheminf.biomedcentral.com/articles/10.1186/s13321-018-0321-8 + + _DTYPE_FINGERPRINT = ( + np.int32 + ) # MHFingerprints seemingly can't handle floats, so can't use handle_errors + def __init__( self, radius: int = 3, @@ -399,7 +399,6 @@ def __init__( n_permutations: int = 2048, seed: int = 42, parallel: Union[bool, int] = False, - handle_errors: bool = False, ): """Transforms the RDKit mol into the MinHash fingerprint (MHFP) @@ -413,7 +412,7 @@ def __init__( this is effectively the length of the FP seed (int, optional): The value used to seed numpy.random. Defaults to 0. """ - super().__init__(parallel=parallel, handle_errors=handle_errors) + super().__init__(parallel=parallel, handle_errors=False) self.radius = radius self.rings = rings self.isomeric = isomeric @@ -438,8 +437,6 @@ def __setstate__(self, state): # Re-create the unpicklable property self._recreate_encoder() - _DTYPE_FINGERPRINT = np.int32 - def _mol2fp(self, mol): fp = self.mhfp_encoder.EncodeMol( mol, self.radius, self.rings, self.isomeric, self.kekulize, self.min_radius diff --git a/tests/fixtures.py b/tests/fixtures.py index 57bcf60..72cf57e 100644 --- a/tests/fixtures.py +++ b/tests/fixtures.py @@ -23,7 +23,11 @@ from scikit_mol.descriptors import MolecularDescriptorTransformer from scikit_mol.conversions import SmilesToMolTransformer from scikit_mol.standardizer import Standardizer -from scikit_mol.core import SKLEARN_VERSION_PANDAS_OUT, DEFAULT_MOL_COLUMN_NAME +from scikit_mol.core import ( + SKLEARN_VERSION_PANDAS_OUT, + DEFAULT_MOL_COLUMN_NAME, + InvalidMol, +) # TODO these should really go into the conftest.py, so that they are automatically imported in the tests @@ -181,3 +185,12 @@ def combined_transformer(featurizer): remainder="drop", ) return transformer + + +@pytest.fixture +def mols_with_invalid_container(): + valid_smiles = ["CC", "CCO", "c1ccccc1"] + invalid_smiles = "NOT_A_VALID_SMILES" + mols = [Chem.MolFromSmiles(s) for s in valid_smiles] + mols.append(InvalidMol("TestError", "Invalid SMILES")) + return mols diff --git a/tests/test_fptransformers.py b/tests/test_fptransformers.py index d149f3a..05cb536 100644 --- a/tests/test_fptransformers.py +++ b/tests/test_fptransformers.py @@ -4,153 +4,275 @@ import numpy as np import pandas as pd from rdkit import Chem -from fixtures import mols_list, smiles_list, mols_container, smiles_container, fingerprint, chiral_smiles_list, chiral_mols_list +from fixtures import ( + mols_list, + smiles_list, + mols_container, + smiles_container, + fingerprint, + chiral_smiles_list, + chiral_mols_list, + mols_with_invalid_container, +) from sklearn import clone -from scikit_mol.fingerprints import MorganFingerprintTransformer, MACCSKeysFingerprintTransformer, RDKitFingerprintTransformer, AtomPairFingerprintTransformer, TopologicalTorsionFingerprintTransformer, SECFingerprintTransformer, MHFingerprintTransformer, AvalonFingerprintTransformer - +from scikit_mol.fingerprints import ( + MorganFingerprintTransformer, + MACCSKeysFingerprintTransformer, + RDKitFingerprintTransformer, + AtomPairFingerprintTransformer, + TopologicalTorsionFingerprintTransformer, + SECFingerprintTransformer, + MHFingerprintTransformer, + AvalonFingerprintTransformer, +) @pytest.fixture def morgan_transformer(): return MorganFingerprintTransformer() + @pytest.fixture def rdkit_transformer(): return RDKitFingerprintTransformer() + @pytest.fixture def atompair_transformer(): return AtomPairFingerprintTransformer() + @pytest.fixture def topologicaltorsion_transformer(): return TopologicalTorsionFingerprintTransformer() + @pytest.fixture def maccs_transformer(): return MACCSKeysFingerprintTransformer() + @pytest.fixture def secfp_transformer(): return SECFingerprintTransformer() + @pytest.fixture def mhfp_transformer(): return MHFingerprintTransformer() + @pytest.fixture def avalon_transformer(): return AvalonFingerprintTransformer() + def test_fpstransformer_fp2array(morgan_transformer, fingerprint): fp = morgan_transformer._fp2array(fingerprint) - #See that fp is the correct type, shape and bit count - assert(type(fp) == type(np.array([0]))) - assert(fp.shape == (1000,)) - assert(fp.sum() == 25) + # See that fp is the correct type, shape and bit count + assert type(fp) == type(np.array([0])) + assert fp.shape == (1000,) + assert fp.sum() == 25 + def test_fpstransformer_transform_mol(morgan_transformer, mols_list): fp = morgan_transformer._transform_mol(mols_list[0]) - #See that fp is the correct type, shape and bit count - assert(type(fp) == type(np.array([0]))) - assert(fp.shape == (2048,)) - assert(fp.sum() == 14) - -def test_clonability(maccs_transformer, morgan_transformer, rdkit_transformer, atompair_transformer, topologicaltorsion_transformer, secfp_transformer, mhfp_transformer, avalon_transformer): - for t in [maccs_transformer, morgan_transformer, rdkit_transformer, atompair_transformer, topologicaltorsion_transformer, secfp_transformer, mhfp_transformer, avalon_transformer]: - params = t.get_params() + # See that fp is the correct type, shape and bit count + assert type(fp) == type(np.array([0])) + assert fp.shape == (2048,) + assert fp.sum() == 14 + + +def test_clonability( + maccs_transformer, + morgan_transformer, + rdkit_transformer, + atompair_transformer, + topologicaltorsion_transformer, + secfp_transformer, + mhfp_transformer, + avalon_transformer, +): + for t in [ + maccs_transformer, + morgan_transformer, + rdkit_transformer, + atompair_transformer, + topologicaltorsion_transformer, + secfp_transformer, + mhfp_transformer, + avalon_transformer, + ]: + params = t.get_params() t2 = clone(t) params_2 = t2.get_params() - #Parameters of cloned transformers should be the same - assert all([ params[key] == params_2[key] for key in params.keys()]) - #Cloned transformers should not be the same object + # Parameters of cloned transformers should be the same + assert all([params[key] == params_2[key] for key in params.keys()]) + # Cloned transformers should not be the same object assert t2 != t -def test_set_params(morgan_transformer, rdkit_transformer, atompair_transformer, topologicaltorsion_transformer, secfp_transformer, mhfp_transformer, avalon_transformer): - for t in [morgan_transformer, atompair_transformer, topologicaltorsion_transformer, avalon_transformer]: - params = t.get_params() - #change extracted dictionary - params['nBits'] = 4242 - #change params in transformer - t.set_params(nBits = 4242) + +def test_set_params( + morgan_transformer, + rdkit_transformer, + atompair_transformer, + topologicaltorsion_transformer, + secfp_transformer, + mhfp_transformer, + avalon_transformer, +): + for t in [ + morgan_transformer, + atompair_transformer, + topologicaltorsion_transformer, + avalon_transformer, + ]: + params = t.get_params() + # change extracted dictionary + params["nBits"] = 4242 + # change params in transformer + t.set_params(nBits=4242) # get parameters as dictionary and assert that it is the same params_2 = t.get_params() - assert all([ params[key] == params_2[key] for key in params.keys()]) + assert all([params[key] == params_2[key] for key in params.keys()]) for t in [rdkit_transformer]: - params = t.get_params() - params['fpSize'] = 4242 - t.set_params(fpSize = 4242) + params = t.get_params() + params["fpSize"] = 4242 + t.set_params(fpSize=4242) params_2 = t.get_params() - assert all([ params[key] == params_2[key] for key in params.keys()]) + assert all([params[key] == params_2[key] for key in params.keys()]) for t in [secfp_transformer]: - params = t.get_params() - params['length'] = 4242 - t.set_params(length = 4242) + params = t.get_params() + params["length"] = 4242 + t.set_params(length=4242) params_2 = t.get_params() - assert all([ params[key] == params_2[key] for key in params.keys()]) + assert all([params[key] == params_2[key] for key in params.keys()]) for t in [mhfp_transformer]: - params = t.get_params() - params['n_permutations'] = 4242 - t.set_params(n_permutations = 4242) + params = t.get_params() + params["n_permutations"] = 4242 + t.set_params(n_permutations=4242) params_2 = t.get_params() - assert all([ params[key] == params_2[key] for key in params.keys()]) - -def test_transform(mols_container, morgan_transformer, rdkit_transformer, atompair_transformer, topologicaltorsion_transformer, maccs_transformer, secfp_transformer, mhfp_transformer, avalon_transformer): - #Test the different transformers - for t in [morgan_transformer, atompair_transformer, topologicaltorsion_transformer, maccs_transformer, rdkit_transformer, secfp_transformer, mhfp_transformer, avalon_transformer]: - params = t.get_params() + assert all([params[key] == params_2[key] for key in params.keys()]) + + +def test_transform( + mols_container, + morgan_transformer, + rdkit_transformer, + atompair_transformer, + topologicaltorsion_transformer, + maccs_transformer, + secfp_transformer, + mhfp_transformer, + avalon_transformer, +): + # Test the different transformers + for t in [ + morgan_transformer, + atompair_transformer, + topologicaltorsion_transformer, + maccs_transformer, + rdkit_transformer, + secfp_transformer, + mhfp_transformer, + avalon_transformer, + ]: + params = t.get_params() fps = t.transform(mols_container) - #Assert that the same length of input and output + # Assert that the same length of input and output assert len(fps) == len(mols_container) # assert that the size of the fingerprint is the expected size - if type(t) == type(maccs_transformer) or type(t) == type(secfp_transformer) or type(t) == type(mhfp_transformer): + if ( + type(t) == type(maccs_transformer) + or type(t) == type(secfp_transformer) + or type(t) == type(mhfp_transformer) + ): fpsize = t.nBits elif type(t) == type(rdkit_transformer): - fpsize = params['fpSize'] + fpsize = params["fpSize"] else: - fpsize = params['nBits'] - + fpsize = params["nBits"] + assert len(fps[0]) == fpsize -def test_transform_parallel(mols_container, morgan_transformer, rdkit_transformer, atompair_transformer, topologicaltorsion_transformer, maccs_transformer, secfp_transformer, mhfp_transformer, avalon_transformer): - #Test the different transformers - for t in [morgan_transformer, atompair_transformer, topologicaltorsion_transformer, maccs_transformer, rdkit_transformer, secfp_transformer, mhfp_transformer, avalon_transformer]: + +def test_transform_parallel( + mols_container, + morgan_transformer, + rdkit_transformer, + atompair_transformer, + topologicaltorsion_transformer, + maccs_transformer, + secfp_transformer, + mhfp_transformer, + avalon_transformer, +): + # Test the different transformers + for t in [ + morgan_transformer, + atompair_transformer, + topologicaltorsion_transformer, + maccs_transformer, + rdkit_transformer, + secfp_transformer, + mhfp_transformer, + avalon_transformer, + ]: t.set_params(parallel=True) - params = t.get_params() + params = t.get_params() fps = t.transform(mols_container) - #Assert that the same length of input and output + # Assert that the same length of input and output assert len(fps) == len(mols_container) # assert that the size of the fingerprint is the expected size - if type(t) == type(maccs_transformer) or type(t) == type(secfp_transformer) or type(t) == type(mhfp_transformer): + if ( + type(t) == type(maccs_transformer) + or type(t) == type(secfp_transformer) + or type(t) == type(mhfp_transformer) + ): fpsize = t.nBits elif type(t) == type(rdkit_transformer): - fpsize = params['fpSize'] + fpsize = params["fpSize"] else: - fpsize = params['nBits'] - + fpsize = params["nBits"] + assert len(fps[0]) == fpsize -def test_picklable(morgan_transformer, rdkit_transformer, atompair_transformer, topologicaltorsion_transformer, maccs_transformer, secfp_transformer, avalon_transformer): - #Test the different transformers - for t in [morgan_transformer, atompair_transformer, topologicaltorsion_transformer, maccs_transformer, rdkit_transformer, secfp_transformer, avalon_transformer]: +def test_picklable( + morgan_transformer, + rdkit_transformer, + atompair_transformer, + topologicaltorsion_transformer, + maccs_transformer, + secfp_transformer, + avalon_transformer, +): + # Test the different transformers + for t in [ + morgan_transformer, + atompair_transformer, + topologicaltorsion_transformer, + maccs_transformer, + rdkit_transformer, + secfp_transformer, + avalon_transformer, + ]: with tempfile.NamedTemporaryFile() as f: pickle.dump(t, f) f.seek(0) t2 = pickle.load(f) - assert(t.get_params() == t2.get_params()) - + assert t.get_params() == t2.get_params() + def assert_transformer_set_params(tr_class, new_params, mols_list): default_params = tr_class().get_params() for key in new_params.keys(): - tr = tr_class() params = tr.get_params() params[key] = new_params[key] @@ -164,20 +286,36 @@ def assert_transformer_set_params(tr_class, new_params, mols_list): fps_init_new_params = new_tr.transform(mols_list) # Now fp_default should not be the same as fp_reset_params - assert ~np.all([np.array_equal(fp_default, fp_reset_params) for fp_default, fp_reset_params in zip(fps_default, fps_reset_params)]), f"Assertation error, FP appears the same, although the {key} should be changed from {default_params[key]} to {params[key]}" + assert ~np.all( + [ + np.array_equal(fp_default, fp_reset_params) + for fp_default, fp_reset_params in zip(fps_default, fps_reset_params) + ] + ), f"Assertation error, FP appears the same, although the {key} should be changed from {default_params[key]} to {params[key]}" # fp_reset_params and fp_init_new_params should however be the same - assert np.all([np.array_equal(fp_init_new_params, fp_reset_params) for fp_init_new_params, fp_reset_params in zip(fps_init_new_params, fps_reset_params)]) , f"Assertation error, FP appears to be different, although the {key} should be changed back as well as initialized to {params[key]}" + assert np.all( + [ + np.array_equal(fp_init_new_params, fp_reset_params) + for fp_init_new_params, fp_reset_params in zip( + fps_init_new_params, fps_reset_params + ) + ] + ), f"Assertation error, FP appears to be different, although the {key} should be changed back as well as initialized to {params[key]}" def test_morgan_set_params(chiral_mols_list): - new_params = {'nBits': 1024, - 'radius': 1, - 'useBondTypes': False,# TODO, why doesn't this change the FP? - 'useChirality': True, - 'useCounts': True, - 'useFeatures': True} - - assert_transformer_set_params(MorganFingerprintTransformer, new_params, chiral_mols_list) + new_params = { + "nBits": 1024, + "radius": 1, + "useBondTypes": False, # TODO, why doesn't this change the FP? + "useChirality": True, + "useCounts": True, + "useFeatures": True, + } + + assert_transformer_set_params( + MorganFingerprintTransformer, new_params, chiral_mols_list + ) def test_atompairs_set_params(chiral_mols_list): @@ -186,71 +324,182 @@ def test_atompairs_set_params(chiral_mols_list): #'confId': -1, #'fromAtoms': 1, #'ignoreAtoms': 0, - 'includeChirality': True, - 'maxLength': 3, - 'minLength': 3, - 'nBits': 1024, - 'nBitsPerEntry': 3, + "includeChirality": True, + "maxLength": 3, + "minLength": 3, + "nBits": 1024, + "nBitsPerEntry": 3, #'use2D': True, #TODO, understand why this can't be set different - 'useCounts': True} - - assert_transformer_set_params(AtomPairFingerprintTransformer, new_params, chiral_mols_list) + "useCounts": True, + } + + assert_transformer_set_params( + AtomPairFingerprintTransformer, new_params, chiral_mols_list + ) def test_topologicaltorsion_set_params(chiral_mols_list): - new_params = {#'atomInvariants': 0, - #'fromAtoms': 0, - #'ignoreAtoms': 0, - #'includeChirality': True, #TODO, figure out why this setting seems to give same FP wheter toggled or not - 'nBits': 1024, - 'nBitsPerEntry': 3, - 'targetSize': 5, - 'useCounts': True} - - assert_transformer_set_params(TopologicalTorsionFingerprintTransformer, new_params, chiral_mols_list) + new_params = { #'atomInvariants': 0, + #'fromAtoms': 0, + #'ignoreAtoms': 0, + #'includeChirality': True, #TODO, figure out why this setting seems to give same FP wheter toggled or not + "nBits": 1024, + "nBitsPerEntry": 3, + "targetSize": 5, + "useCounts": True, + } + + assert_transformer_set_params( + TopologicalTorsionFingerprintTransformer, new_params, chiral_mols_list + ) + def test_RDKitFPTransformer(chiral_mols_list): - new_params = {#'atomInvariantsGenerator': None, - #'branchedPaths': False, - #'countBounds': 0, #TODO: What does this do? - 'countSimulation': True, - 'fpSize': 1024, - 'maxPath': 3, - 'minPath': 2, - 'numBitsPerFeature': 3, - 'useBondOrder': False, #TODO, why doesn't this change the FP? - #'useHs': False, #TODO, why doesn't this change the FP? - } - assert_transformer_set_params(RDKitFingerprintTransformer, new_params, chiral_mols_list) + new_params = { #'atomInvariantsGenerator': None, + #'branchedPaths': False, + #'countBounds': 0, #TODO: What does this do? + "countSimulation": True, + "fpSize": 1024, + "maxPath": 3, + "minPath": 2, + "numBitsPerFeature": 3, + "useBondOrder": False, # TODO, why doesn't this change the FP? + #'useHs': False, #TODO, why doesn't this change the FP? + } + assert_transformer_set_params( + RDKitFingerprintTransformer, new_params, chiral_mols_list + ) def test_SECFingerprintTransformer(chiral_mols_list): - new_params = {'isomeric': True, - 'kekulize': True, - 'length': 1048, - 'min_radius': 2, - #'n_permutations': 2, # The SECFp is not using this setting - 'radius': 2, - 'rings': False, - #'seed': 1 # The SECFp is not using this setting - } - assert_transformer_set_params(SECFingerprintTransformer, new_params, chiral_mols_list) + new_params = { + "isomeric": True, + "kekulize": True, + "length": 1048, + "min_radius": 2, + #'n_permutations': 2, # The SECFp is not using this setting + "radius": 2, + "rings": False, + #'seed': 1 # The SECFp is not using this setting + } + assert_transformer_set_params( + SECFingerprintTransformer, new_params, chiral_mols_list + ) + def test_MHFingerprintTransformer(chiral_mols_list): - new_params = {'radius': 2, - 'rings': False, - 'isomeric': True, - 'kekulize': True, - 'min_radius': 2, - 'n_permutations': 4096, - 'seed': 44 - } - assert_transformer_set_params(MHFingerprintTransformer, new_params, chiral_mols_list) + new_params = { + "radius": 2, + "rings": False, + "isomeric": True, + "kekulize": True, + "min_radius": 2, + "n_permutations": 4096, + "seed": 44, + } + assert_transformer_set_params( + MHFingerprintTransformer, new_params, chiral_mols_list + ) + def test_AvalonFingerprintTransformer(chiral_mols_list): - new_params = {'nBits': 1024, - 'isQuery': True, - # 'resetVect': True, #TODO: this doesn't change the FP - 'bitFlags': 32767 - } - assert_transformer_set_params(AvalonFingerprintTransformer, new_params, chiral_mols_list) + new_params = { + "nBits": 1024, + "isQuery": True, + # 'resetVect': True, #TODO: this doesn't change the FP + "bitFlags": 32767, + } + assert_transformer_set_params( + AvalonFingerprintTransformer, new_params, chiral_mols_list + ) + + +def test_transform_with_error_handling( + mols_with_invalid_container, + morgan_transformer, + rdkit_transformer, + atompair_transformer, + topologicaltorsion_transformer, + maccs_transformer, + secfp_transformer, + avalon_transformer, +): + for t in [ + morgan_transformer, + atompair_transformer, + topologicaltorsion_transformer, + maccs_transformer, + rdkit_transformer, + secfp_transformer, + avalon_transformer, + ]: + t.set_params(handle_errors=True) + print(type(t)) + fps = t.transform(mols_with_invalid_container) + + assert len(fps) == len(mols_with_invalid_container) + + # Check that the last row (corresponding to the InvalidMol) contains NaNs + assert np.all(np.isnan(fps[-1])) + + # Check that other rows don't contain NaNs + assert not np.any(np.isnan(fps[:-1])) + + +def test_transform_without_error_handling( + mols_with_invalid_container, + morgan_transformer, + rdkit_transformer, + atompair_transformer, + topologicaltorsion_transformer, + maccs_transformer, + secfp_transformer, + avalon_transformer, +): + for t in [ + morgan_transformer, + atompair_transformer, + topologicaltorsion_transformer, + maccs_transformer, + rdkit_transformer, + secfp_transformer, + avalon_transformer, + ]: + t.set_params(handle_errors=False) + with pytest.raises( + Exception + ): # You might want to be more specific about the exception type + print(f"testing {type(t)}") + t.transform(mols_with_invalid_container) + + +# Add this test to check parallel processing with error handling +def test_transform_parallel_with_error_handling( + mols_with_invalid_container, + morgan_transformer, + rdkit_transformer, + atompair_transformer, + topologicaltorsion_transformer, + maccs_transformer, + secfp_transformer, + avalon_transformer, +): + for t in [ + morgan_transformer, + atompair_transformer, + topologicaltorsion_transformer, + maccs_transformer, + rdkit_transformer, + secfp_transformer, + avalon_transformer, + ]: + t.set_params(handle_errors=True, parallel=True) + fps = t.transform(mols_with_invalid_container) + + assert len(fps) == len(mols_with_invalid_container) + + # Check that the last row (corresponding to the InvalidMol) contains NaNs + assert np.all(np.isnan(fps[-1])) + + # Check that other rows don't contain NaNs + assert not np.any(np.isnan(fps[:-1])) From 70a0598c3a546b7a244b621c0cfdb95b76bd372c Mon Sep 17 00:00:00 2001 From: Esben Jannik Bjerrum Date: Thu, 3 Oct 2024 20:33:07 +0200 Subject: [PATCH 30/41] Added test of the fingerprint classes for error handling --- tests/fixtures.py | 21 ++++++++++++--------- tests/test_fptransformers.py | 1 + 2 files changed, 13 insertions(+), 9 deletions(-) diff --git a/tests/fixtures.py b/tests/fixtures.py index 72cf57e..2b5a2e6 100644 --- a/tests/fixtures.py +++ b/tests/fixtures.py @@ -114,6 +114,18 @@ def chiral_mols_list(chiral_smiles_list): return [Chem.MolFromSmiles(smiles) for smiles in chiral_smiles_list] +@pytest.fixture +def mols_with_invalid_container(invalid_smiles_list): + mols = [] + for smiles in invalid_smiles_list: + mol = Chem.MolFromSmiles(smiles) + if mol is None: + mols.append(InvalidMol("TestError", f"Invalid SMILES: {smiles}")) + else: + mols.append(mol) + return mols + + @pytest.fixture def fingerprint(mols_list): return rdMolDescriptors.GetHashedMorganFingerprint(mols_list[0], 2, nBits=1000) @@ -185,12 +197,3 @@ def combined_transformer(featurizer): remainder="drop", ) return transformer - - -@pytest.fixture -def mols_with_invalid_container(): - valid_smiles = ["CC", "CCO", "c1ccccc1"] - invalid_smiles = "NOT_A_VALID_SMILES" - mols = [Chem.MolFromSmiles(s) for s in valid_smiles] - mols.append(InvalidMol("TestError", "Invalid SMILES")) - return mols diff --git a/tests/test_fptransformers.py b/tests/test_fptransformers.py index 05cb536..290e747 100644 --- a/tests/test_fptransformers.py +++ b/tests/test_fptransformers.py @@ -13,6 +13,7 @@ chiral_smiles_list, chiral_mols_list, mols_with_invalid_container, + invalid_smiles_list, ) from sklearn import clone From e1b2557a02f3169aac58e613a267c500f3ba9e98 Mon Sep 17 00:00:00 2001 From: Esben Jannik Bjerrum Date: Wed, 9 Oct 2024 20:03:46 +0200 Subject: [PATCH 31/41] Changed name to safe_inference mode consistently and fixed the pytests for the fingerprints (which highlighted a bug in the parallel code that was fixed). --- scikit_mol/conversions.py | 30 +++++++-- scikit_mol/descriptors.py | 11 ++-- scikit_mol/fingerprints.py | 77 ++++++++++++---------- scikit_mol/standardizer.py | 57 ++++++++-------- scikit_mol/utilities.py | 32 ++++----- scikit_mol/wrapper.py | 123 ++++++++++++++++++++++++++++++----- tests/test_fptransformers.py | 28 ++++---- 7 files changed, 244 insertions(+), 114 deletions(-) diff --git a/scikit_mol/conversions.py b/scikit_mol/conversions.py index c18d1bb..0f5b11d 100644 --- a/scikit_mol/conversions.py +++ b/scikit_mol/conversions.py @@ -17,10 +17,28 @@ class SmilesToMolTransformer(BaseEstimator, TransformerMixin): - def __init__(self, parallel: Union[bool, int] = False, handle_errors: bool = False): + """ + Transformer for converting SMILES strings to RDKit mol objects. + + This transformer can be included in pipelines during development and training, + but the safe inference mode should only be enabled when deploying models for + inference in production environments. + + Parameters: + ----------- + parallel : Union[bool, int], default=False + If True or int > 1, enables parallel processing. + safe_inference_mode : bool, default=False + If True, enables safeguards for handling invalid data during inference. + This should only be set to True when deploying models to production. + """ + + def __init__( + self, parallel: Union[bool, int] = False, safe_inference_mode: bool = False + ): self.parallel = parallel self.start_method = None # TODO implement handling of start_method - self.handle_errors = handle_errors + self.safe_inference_mode = safe_inference_mode @feature_names_default_mol def get_feature_names_out(self, input_features=None): @@ -46,7 +64,7 @@ def transform(self, X_smiles_list, y=None): Raises ------ ValueError - Raises ValueError if a SMILES string is unparsable by RDKit + Raises ValueError if a SMILES string is unparsable by RDKit and safe_inference_mode is False """ if not self.parallel: @@ -84,11 +102,11 @@ def _transform(self, X): else: message = f"Invalid SMILES: {smiles}" X_out.append(InvalidMol(str(self), message)) - if not self.handle_errors and not all(X_out): + if not self.safe_inference_mode and not all(X_out): fails = [x for x in X_out if not x] raise ValueError( f"Invalid SMILES found: {fails}." - ) # TODO with this appraoch we get all errors, but we do process ALL the smiles first which could be slow + ) # TODO with this approach we get all errors, but we do process ALL the smiles first which could be slow return np.array(X_out).reshape(-1, 1) @check_transform_input @@ -109,7 +127,7 @@ def inverse_transform(self, X_mols_list, y=None): else: X_out.append(InvalidMol(str(self), f"Not a Mol: {mol}")) - if not self.handle_errors and not all(isinstance(x, str) for x in X_out): + if not self.safe_inference_mode and not all(isinstance(x, str) for x in X_out): fails = [x for x in X_out if not isinstance(x, str)] raise ValueError(f"Invalid Mols found: {fails}.") diff --git a/scikit_mol/descriptors.py b/scikit_mol/descriptors.py index c461051..2a4c79c 100644 --- a/scikit_mol/descriptors.py +++ b/scikit_mol/descriptors.py @@ -25,6 +25,9 @@ class MolecularDescriptorTransformer(BaseEstimator, TransformerMixin): start_method : str The method to start child processes when parallel=True. can be 'fork', 'spawn' or 'forkserver'. If None, the OS and Pythons default will be used. + safe_inference_mode : bool + If True, enables safeguards for handling invalid data during inference. + This should only be set to True when deploying models to production. Returns ------- @@ -39,12 +42,12 @@ def __init__( desc_list: Optional[str] = None, parallel: Union[bool, int] = False, start_method: str = None, # "fork", - handle_errors: bool = False, + safe_inference_mode: bool = False, ): self.desc_list = desc_list self.parallel = parallel self.start_method = start_method - self.handle_errors = handle_errors + self.safe_inference_mode = safe_inference_mode def _get_desc_calculator(self) -> MolecularDescriptorCalculator: if self.desc_list: @@ -96,12 +99,12 @@ def start_method(self, start_method): self._start_method = start_method def _transform_mol(self, mol: Mol) -> List[Any]: - if not mol and self.handle_errors: + if not mol and self.safe_inference_mode: return [np.nan] * len(self.desc_list) try: return list(self.calculators.CalcDescriptors(mol)) except Exception as e: - if self.handle_errors: + if self.safe_inference_mode: return [np.nan] * len(self.desc_list) else: raise e diff --git a/scikit_mol/fingerprints.py b/scikit_mol/fingerprints.py index de211dc..86274f6 100644 --- a/scikit_mol/fingerprints.py +++ b/scikit_mol/fingerprints.py @@ -1,4 +1,3 @@ -# %% from multiprocessing import Pool, get_context import multiprocessing import re @@ -28,21 +27,20 @@ ) -# %% class FpsTransformer(ABC, BaseEstimator, TransformerMixin): def __init__( self, parallel: Union[bool, int] = False, start_method: str = None, - handle_errors: bool = False, + safe_inference_mode: bool = False, ): self.parallel = parallel self.start_method = start_method - self.handle_errors = handle_errors + self.safe_inference_mode = safe_inference_mode # The dtype of the fingerprint array computed by the transformer # If needed this property can be overwritten in the child class. - _DTYPE_FINGERPRINT = float # Float is necessary for the handle_errors to work + _DTYPE_FINGERPRINT = np.int8 def _get_column_prefix(self) -> str: matched = _PATTERN_FINGERPRINT_TRANSFORMER.match(type(self).__name__) @@ -86,21 +84,21 @@ def _mol2fp(self, mol): raise NotImplementedError("_mol2fp not implemented") def _fp2array(self, fp): - arr = np.zeros((self.nBits,), dtype=self._DTYPE_FINGERPRINT) if fp: + arr = np.zeros((self.nBits,), dtype=self._DTYPE_FINGERPRINT) DataStructs.ConvertToNumpyArray(fp, arr) + return arr else: - arr[:] = np.nan # Sadly, dtype=int8 does not allow for NaN values - return arr + return np.ma.masked_all((self.nBits,), dtype=self._DTYPE_FINGERPRINT) def _transform_mol(self, mol): - if not mol and self.handle_errors: + if not mol and self.safe_inference_mode: return self._fp2array(False) try: fp = self._mol2fp(mol) return self._fp2array(fp) except Exception as e: - if self.handle_errors: + if self.safe_inference_mode: return self._fp2array(False) else: raise e @@ -114,10 +112,16 @@ def fit(self, X, y=None): @check_transform_input def _transform(self, X): - arr = np.zeros((len(X), self.nBits), dtype=self._DTYPE_FINGERPRINT) - for i, mol in enumerate(X): - arr[i, :] = self._transform_mol(mol) - return arr + if self.safe_inference_mode: + # Use the new method with masked arrays if we're in safe inference mode + arrays = [self._transform_mol(mol) for mol in X] + return np.ma.stack(arrays) + else: + # Use the original, faster method if we're not in safe inference mode + arr = np.zeros((len(X), self.nBits), dtype=self._DTYPE_FINGERPRINT) + for i, mol in enumerate(X): + arr[i, :] = self._transform_mol(mol) + return arr def _transform_sparse(self, X): arr = np.zeros((len(X), self.nBits), dtype=self._DTYPE_FINGERPRINT) @@ -167,17 +171,21 @@ def transform(self, X, y=None): for x_chunk in x_chunks ], ) - - arr = np.concatenate(arrays) + if self.safe_inference_mode: + arr = np.ma.concatenate(arrays) + else: + arr = np.concatenate(arrays) return arr class MACCSKeysFingerprintTransformer(FpsTransformer): - def __init__(self, parallel: Union[bool, int] = False, handle_errors: bool = False): + def __init__( + self, parallel: Union[bool, int] = False, safe_inference_mode: bool = False + ): """MACCS keys fingerprinter calculates the 167 fixed MACCS keys """ - super().__init__(parallel=parallel, handle_errors=handle_errors) + super().__init__(parallel=parallel, safe_inference_mode=safe_inference_mode) self.nBits = 167 @property @@ -210,7 +218,7 @@ def __init__( numBitsPerFeature: int = 2, atomInvariantsGenerator=None, parallel: Union[bool, int] = False, - handle_errors: bool = False, + safe_inference_mode: bool = False, ): """Calculates the RDKit fingerprints @@ -237,7 +245,7 @@ def __init__( atomInvariantsGenerator : _type_, optional atom invariants to be used during fingerprint generation, by default None """ - super().__init__(parallel=parallel, handle_errors=handle_errors) + super().__init__(parallel=parallel, safe_inference_mode=safe_inference_mode) self.minPath = minPath self.maxPath = maxPath self.useHs = useHs @@ -289,9 +297,9 @@ def __init__( nBits=2048, useCounts: bool = False, parallel: Union[bool, int] = False, - handle_errors: bool = False, + safe_inference_mode: bool = False, ): - super().__init__(parallel=parallel, handle_errors=handle_errors) + super().__init__(parallel=parallel, safe_inference_mode=safe_inference_mode) self.minLength = minLength self.maxLength = maxLength self.fromAtoms = fromAtoms @@ -346,9 +354,9 @@ def __init__( nBits=2048, useCounts: bool = False, parallel: Union[bool, int] = False, - handle_errors: bool = False, + safe_inference_mode: bool = False, ): - super().__init__(parallel=parallel, handle_errors=handle_errors) + super().__init__(parallel=parallel, safe_inference_mode=safe_inference_mode) self.targetSize = targetSize self.fromAtoms = fromAtoms self.ignoreAtoms = ignoreAtoms @@ -385,9 +393,7 @@ def _mol2fp(self, mol): class MHFingerprintTransformer(FpsTransformer): # https://jcheminf.biomedcentral.com/articles/10.1186/s13321-018-0321-8 - _DTYPE_FINGERPRINT = ( - np.int32 - ) # MHFingerprints seemingly can't handle floats, so can't use handle_errors + _DTYPE_FINGERPRINT = np.int32 def __init__( self, @@ -399,6 +405,7 @@ def __init__( n_permutations: int = 2048, seed: int = 42, parallel: Union[bool, int] = False, + safe_inference_mode: bool = False, ): """Transforms the RDKit mol into the MinHash fingerprint (MHFP) @@ -412,7 +419,7 @@ def __init__( this is effectively the length of the FP seed (int, optional): The value used to seed numpy.random. Defaults to 0. """ - super().__init__(parallel=parallel, handle_errors=False) + super().__init__(parallel=parallel, safe_inference_mode=safe_inference_mode) self.radius = radius self.rings = rings self.isomeric = isomeric @@ -490,7 +497,7 @@ def __init__( n_permutations: int = 0, seed: int = 0, parallel: Union[bool, int] = False, - handle_errors: bool = False, + safe_inference_mode: bool = False, ): """Transforms the RDKit mol into the SMILES extended connectivity fingerprint (SECFP) @@ -504,7 +511,7 @@ def __init__( n_permutations (int, optional): The number of permutations used for hashing. Defaults to 0. seed (int, optional): The value used to seed numpy.random. Defaults to 0. """ - super().__init__(parallel=parallel, handle_errors=handle_errors) + super().__init__(parallel=parallel, safe_inference_mode=safe_inference_mode) self.radius = radius self.rings = rings self.isomeric = isomeric @@ -582,7 +589,7 @@ def __init__( useFeatures=False, useCounts=False, parallel: Union[bool, int] = False, - handle_errors: bool = False, + safe_inference_mode: bool = False, ): """Transform RDKit mols into Count or bit-based hashed MorganFingerprints @@ -601,7 +608,7 @@ def __init__( useCounts : bool, optional If toggled will create the count and not bit-based fingerprint, by default False """ - super().__init__(parallel=parallel, handle_errors=handle_errors) + super().__init__(parallel=parallel, safe_inference_mode=safe_inference_mode) self.nBits = nBits self.radius = radius self.useChirality = useChirality @@ -640,7 +647,7 @@ def __init__( bitFlags: int = 15761407, useCounts: bool = False, parallel: Union[bool, int] = False, - handle_errors: bool = False, + safe_inference_mode: bool = False, ): """Transform RDKit mols into Count or bit-based Avalon Fingerprints @@ -657,7 +664,7 @@ def __init__( useCounts : bool, optional If toggled will create the count and not bit-based fingerprint, by default False """ - super().__init__(parallel=parallel, handle_errors=handle_errors) + super().__init__(parallel=parallel, safe_inference_mode=safe_inference_mode) self.nBits = nBits self.isQuery = isQuery self.resetVect = resetVect @@ -685,7 +692,7 @@ def _mol2fp(self, mol): def parallel_helper(args): """Parallel_helper takes a tuple with classname, the objects parameters and the mols to process. Then instantiates the class with the parameters and processes the mol. - Intention is to be able to do this in chilcprocesses as some classes can't be pickled""" + Intention is to be able to do this in child processes as some classes can't be pickled""" classname, parameters, X_mols = args from scikit_mol import fingerprints diff --git a/scikit_mol/standardizer.py b/scikit_mol/standardizer.py index 105616f..29a7722 100644 --- a/scikit_mol/standardizer.py +++ b/scikit_mol/standardizer.py @@ -8,43 +8,50 @@ from rdkit.rdBase import BlockLogs import numpy as np -from scikit_mol.core import check_transform_input, feature_names_default_mol +from scikit_mol.core import check_transform_input, feature_names_default_mol, InvalidMol class Standardizer(BaseEstimator, TransformerMixin): """Input a list of rdkit mols, output the same list but standardised""" - def __init__(self, neutralize=True, parallel=False): + def __init__(self, neutralize=True, parallel=False, safe_inference_mode=False): self.neutralize = neutralize self.parallel = parallel self.start_method = None # TODO implement handling of start_method + self.safe_inference_mode = safe_inference_mode def fit(self, X, y=None): return self - def _transform(self, X): - block = BlockLogs() # Block all RDkit logging - arr = [] - for mol in X: - if mol: # Falsy mols can't be processed, (e.g. if InvalidMol objects) - # Normalizing functional groups - # https://molvs.readthedocs.io/en/latest/guide/standardize.html - clean_mol = rdMolStandardize.Cleanup(mol) - # Get parents fragments - parent_clean_mol = rdMolStandardize.FragmentParent(clean_mol) - # Neutralise - if self.neutralize: - uncharger = rdMolStandardize.Uncharger() - uncharged_parent_clean_mol = uncharger.uncharge(parent_clean_mol) - else: - uncharged_parent_clean_mol = parent_clean_mol - # Add to final list - arr.append(uncharged_parent_clean_mol) + def _standardize_mol(self, mol): + if not mol: + if self.safe_inference_mode: + return InvalidMol(str(self), "Invalid input molecule") else: - arr.append(mol) + raise ValueError("Invalid input molecule") - del block # Release logging block to previous state - return np.array(arr).reshape(-1, 1) + try: + block = BlockLogs() # Block all RDkit logging + # Normalizing functional groups + clean_mol = rdMolStandardize.Cleanup(mol) + # Get parents fragments + parent_clean_mol = rdMolStandardize.FragmentParent(clean_mol) + # Neutralise + if self.neutralize: + uncharger = rdMolStandardize.Uncharger() + uncharged_parent_clean_mol = uncharger.uncharge(parent_clean_mol) + else: + uncharged_parent_clean_mol = parent_clean_mol + del block # Release logging block to previous state + return uncharged_parent_clean_mol + except Exception as e: + if self.safe_inference_mode: + return InvalidMol(str(self), f"Standardization failed: {str(e)}") + else: + raise + + def _transform(self, X): + return np.array([self._standardize_mol(mol) for mol in X]).reshape(-1, 1) @feature_names_default_mol def get_feature_names_out(self, input_features=None): @@ -69,8 +76,6 @@ def transform(self, X, y=None): processes=n_processes ) as pool: x_chunks = np.array_split(X, n_chunks) - # TODO check what is fastest, pickle or recreate and do this only for classes that need this - # arrays = pool.map(self._transform, x_chunks) parameters = self.get_params() arrays = pool.map( parallel_helper, @@ -86,7 +91,7 @@ def transform(self, X, y=None): def parallel_helper(args): """Parallel_helper takes a tuple with classname, the objects parameters and the mols to process. Then instantiates the class with the parameters and processes the mol. - Intention is to be able to do this in chilcprocesses as some classes can't be pickled""" + Intention is to be able to do this in child processes as some classes can't be pickled""" classname, parameters, X_mols = args from scikit_mol import standardizer diff --git a/scikit_mol/utilities.py b/scikit_mol/utilities.py index 481116a..13c360e 100644 --- a/scikit_mol/utilities.py +++ b/scikit_mol/utilities.py @@ -66,58 +66,58 @@ def sanitize(self, X_smiles_list, y=None): return X_out, X_errors -def set_handle_errors(estimator, value): +def set_safe_inference_mode(estimator, value): """ - Recursively set the handle_errors parameter for all compatible estimators. + Recursively set the safe_inference_mode parameter for all compatible estimators. :param estimator: A scikit-learn estimator, pipeline, or custom wrapper - :param value: Boolean value to set for handle_errors + :param value: Boolean value to set for safe_inference_mode """ - def _set_handle_errors_recursive(est, val): - if hasattr(est, "handle_errors"): - est.handle_errors = val + def _set_safe_inference_mode_recursive(est, val): + if hasattr(est, "safe_inference_mode"): + est.safe_inference_mode = val # Handle Pipeline if isinstance(est, Pipeline): for _, step in est.steps: - _set_handle_errors_recursive(step, val) + _set_safe_inference_mode_recursive(step, val) # Handle FeatureUnion elif isinstance(est, FeatureUnion): for _, transformer in est.transformer_list: - _set_handle_errors_recursive(transformer, val) + _set_safe_inference_mode_recursive(transformer, val) # Handle ColumnTransformer elif isinstance(est, ColumnTransformer): for _, transformer, _ in est.transformers: - _set_handle_errors_recursive(transformer, val) + _set_safe_inference_mode_recursive(transformer, val) - # Handle NanGuardWrapper + # Handle SafeInferenceWrapper elif hasattr(est, "estimator") and isinstance(est.estimator, BaseEstimator): - _set_handle_errors_recursive(est.estimator, val) + _set_safe_inference_mode_recursive(est.estimator, val) # Handle other estimators with get_params elif isinstance(est, BaseEstimator): params = est.get_params(deep=False) for param_name, param_value in params.items(): if isinstance(param_value, BaseEstimator): - _set_handle_errors_recursive(param_value, val) + _set_safe_inference_mode_recursive(param_value, val) # Apply the recursive function - _set_handle_errors_recursive(estimator, value) + _set_safe_inference_mode_recursive(estimator, value) # Final check params = estimator.get_params(deep=True) mismatched_params = [ - key.rstrip("__handle_errors") + key.rstrip("__safe_inference_mode") for key, val in params.items() - if key.endswith("__handle_errors") and val != value + if key.endswith("__safe_inference_mode") and val != value ] if mismatched_params: warnings.warn( - f"The following components have 'handle_errors' set to a different value than requested: {mismatched_params}. " + f"The following components have 'safe_inference_mode' set to a different value than requested: {mismatched_params}. " "This could be due to nested estimators that were not properly handled.", UserWarning, ) diff --git a/scikit_mol/wrapper.py b/scikit_mol/wrapper.py index 79c4c08..6134778 100644 --- a/scikit_mol/wrapper.py +++ b/scikit_mol/wrapper.py @@ -6,47 +6,72 @@ import pandas as pd from functools import wraps import warnings -from sklearn.base import BaseEstimator +from sklearn.base import BaseEstimator, TransformerMixin from sklearn.utils.metaestimators import available_if -from sklearn.base import TransformerMixin + + +class MaskedArrayError(ValueError): + """Raised when a masked array is passed but handle_errors is False.""" + + pass def filter_invalid_rows(fill_value=np.nan, warn_on_invalid=False): def decorator(func): @wraps(func) - def wrapper(obj, X, *args, **kwargs): - if not getattr(obj, "handle_errors", True): - # If handle_errors is False, call the original function without filtering - return func(obj, X, *args, **kwargs) - - valid_mask = np.isfinite(X).all(axis=1) # Find all rows with nan, inf, etc. + def wrapper(obj, X, y=None, *args, **kwargs): + if not getattr(obj, "safe_inference_mode", True): + if isinstance(X, np.ma.MaskedArray) and X.mask.any(): + raise MaskedArrayError( + f"Masked array detected with safe_inference_mode=False and {X.mask.any(axis=1).sum()} filtered rows. " + "Set safe_inference_mode=True to process masked arrays for inference of production models." + ) + return func(obj, X, y, *args, **kwargs) + + if isinstance(X, np.ma.MaskedArray): + mask_invalid = X.mask.any(axis=1) + finite_mask = np.isfinite(X.data).all(axis=1) + valid_mask = ~mask_invalid & finite_mask + else: + valid_mask = np.isfinite(X).all(axis=1) if warn_on_invalid and not np.all(valid_mask): warnings.warn( - f"Invalid data detected in {func.__name__}. This may lead to unexpected results.", + f"SafeInferenceWrapper is in safe_inference_mode during use of {func.__name__} and invalid data detected. " + "This mode is intended for safe inference in production, not for training and evaluation.", UserWarning, ) valid_indices = np.where(valid_mask)[0] reduced_X = X[valid_mask] - result = func(obj, reduced_X, *args, **kwargs) + if y is not None: + reduced_y = y[valid_mask] + else: + reduced_y = None - if result is None: # For methods like fit that return None + result = func(obj, reduced_X, reduced_y, *args, **kwargs) + + if result is None: return None if isinstance(result, np.ndarray): - output = np.full((X.shape[0], result.shape[1]), fill_value) + if result.ndim == 1: + output = np.full(X.shape[0], fill_value) + else: + output = np.full((X.shape[0], result.shape[1]), fill_value) output[valid_indices] = result return output elif isinstance(result, pd.DataFrame): - # Create a DataFrame with NaN values for all rows output = pd.DataFrame(index=range(X.shape[0]), columns=result.columns) - # Fill the valid rows with the result data + output.iloc[valid_indices] = result + return output + elif isinstance(result, pd.Series): + output = pd.Series(index=range(X.shape[0]), dtype=result.dtype) output.iloc[valid_indices] = result return output else: - return result # For methods that return non-array results + return result return wrapper @@ -125,3 +150,71 @@ def fit_transform(self, X, y): @filter_invalid_rows(warn_on_invalid=True) def score(self, X, y): return self.estimator.score(X, y) + + +class SafeInferenceWrapper(BaseEstimator, TransformerMixin): + """ + Wrapper for sklearn estimators to ensure safe inference in production environments. + + This wrapper is designed to be applied to trained models for use in production settings. + While it can be included during model development and training, the safe inference mode + should only be enabled when deploying models for inference in production. + + Parameters: + ----------- + estimator : BaseEstimator + The trained sklearn estimator to be wrapped. + safe_inference_mode : bool, default=False + If True, enables safeguards for handling invalid data during inference. + This should only be set to True when deploying models to production. + replace_value : any, default=np.nan + The value to use for replacing invalid data points. + """ + + def __init__( + self, + estimator: BaseEstimator, + safe_inference_mode: bool = True, + replace_value=np.nan, + ): + self.estimator = estimator + self.safe_inference_mode = safe_inference_mode + self.replace_value = replace_value + + @property + def n_features_in_(self): + return self.estimator.n_features_in_ + + @filter_invalid_rows(warn_on_invalid=True) + def fit(self, X, y=None, **fit_params): + return self.estimator.fit(X, y, **fit_params) + + @available_if(lambda self: hasattr(self.estimator, "predict")) + @filter_invalid_rows() + def predict(self, X, y=None): + return self.estimator.predict(X) + + @available_if(lambda self: hasattr(self.estimator, "predict_proba")) + @filter_invalid_rows() + def predict_proba(self, X, y=None): + return self.estimator.predict_proba(X) + + @available_if(lambda self: hasattr(self.estimator, "decision_function")) + @filter_invalid_rows() + def decision_function(self, X, y=None): + return self.estimator.decision_function(X) + + @available_if(lambda self: hasattr(self.estimator, "transform")) + @filter_invalid_rows() + def transform(self, X, y=None): + return self.estimator.transform(X) + + @available_if(lambda self: hasattr(self.estimator, "fit_transform")) + @filter_invalid_rows(warn_on_invalid=True) + def fit_transform(self, X, y=None, **fit_params): + return self.estimator.fit_transform(X, y, **fit_params) + + @available_if(lambda self: hasattr(self.estimator, "score")) + @filter_invalid_rows(warn_on_invalid=True) + def score(self, X, y=None): + return self.estimator.score(X, y) diff --git a/tests/test_fptransformers.py b/tests/test_fptransformers.py index 290e747..9a9c27a 100644 --- a/tests/test_fptransformers.py +++ b/tests/test_fptransformers.py @@ -415,7 +415,7 @@ def test_AvalonFingerprintTransformer(chiral_mols_list): ) -def test_transform_with_error_handling( +def test_transform_with_safe_inference_mode( mols_with_invalid_container, morgan_transformer, rdkit_transformer, @@ -434,20 +434,20 @@ def test_transform_with_error_handling( secfp_transformer, avalon_transformer, ]: - t.set_params(handle_errors=True) + t.set_params(safe_inference_mode=True) print(type(t)) fps = t.transform(mols_with_invalid_container) assert len(fps) == len(mols_with_invalid_container) # Check that the last row (corresponding to the InvalidMol) contains NaNs - assert np.all(np.isnan(fps[-1])) + assert np.all(fps.mask[-1]) # Check that other rows don't contain NaNs - assert not np.any(np.isnan(fps[:-1])) + assert not np.any(fps.mask[:-1]) -def test_transform_without_error_handling( +def test_transform_without_safe_inference_mode( mols_with_invalid_container, morgan_transformer, rdkit_transformer, @@ -456,6 +456,7 @@ def test_transform_without_error_handling( maccs_transformer, secfp_transformer, avalon_transformer, + # MHFP seem to accept invalid mols and return 0,0,0,0's ): for t in [ morgan_transformer, @@ -466,7 +467,7 @@ def test_transform_without_error_handling( secfp_transformer, avalon_transformer, ]: - t.set_params(handle_errors=False) + t.set_params(safe_inference_mode=False) with pytest.raises( Exception ): # You might want to be more specific about the exception type @@ -475,7 +476,7 @@ def test_transform_without_error_handling( # Add this test to check parallel processing with error handling -def test_transform_parallel_with_error_handling( +def test_transform_parallel_with_safe_inference_mode( mols_with_invalid_container, morgan_transformer, rdkit_transformer, @@ -494,13 +495,16 @@ def test_transform_parallel_with_error_handling( secfp_transformer, avalon_transformer, ]: - t.set_params(handle_errors=True, parallel=True) + t.set_params(safe_inference_mode=True, parallel=True) fps = t.transform(mols_with_invalid_container) assert len(fps) == len(mols_with_invalid_container) - # Check that the last row (corresponding to the InvalidMol) contains NaNs - assert np.all(np.isnan(fps[-1])) + print(fps.mask) + # Check that the last row (corresponding to the InvalidMol) is masked + assert np.all( + fps.mask[-1] + ) # Mask should be true for all elements in the last row - # Check that other rows don't contain NaNs - assert not np.any(np.isnan(fps[:-1])) + # Check that other rows don't contain any masked values + assert not np.any(fps.mask[:-1, :]) From 35c33a299d4c87ceb57323554c3890ee5e92230d Mon Sep 17 00:00:00 2001 From: Esben Jannik Bjerrum Date: Wed, 9 Oct 2024 20:59:58 +0200 Subject: [PATCH 32/41] Created test for descriptor transformer and fixed bugs in descriptor transformer and the wrapper. Now its possible to select if nonfinite values should be masked or not masked. --- scikit_mol/descriptors.py | 43 ++++--- scikit_mol/wrapper.py | 133 +++++++++++---------- tests/test_desctransformer.py | 214 +++++++++++++++++++++++++--------- 3 files changed, 257 insertions(+), 133 deletions(-) diff --git a/scikit_mol/descriptors.py b/scikit_mol/descriptors.py index 2a4c79c..b852448 100644 --- a/scikit_mol/descriptors.py +++ b/scikit_mol/descriptors.py @@ -98,14 +98,17 @@ def start_method(self, start_method): ), f"start_method not in allowed methods {allowed_start_methods}" self._start_method = start_method - def _transform_mol(self, mol: Mol) -> List[Any]: - if not mol and self.safe_inference_mode: - return [np.nan] * len(self.desc_list) + def _transform_mol(self, mol: Mol) -> Union[np.ndarray, np.ma.MaskedArray]: + if not mol: + if self.safe_inference_mode: + return np.ma.masked_all(len(self.desc_list)) + else: + raise ValueError("Invalid molecule provided: {mol}") try: - return list(self.calculators.CalcDescriptors(mol)) + return np.array(list(self.calculators.CalcDescriptors(mol))) except Exception as e: if self.safe_inference_mode: - return [np.nan] * len(self.desc_list) + return np.ma.masked_all(len(self.desc_list)) else: raise e @@ -114,13 +117,17 @@ def fit(self, x, y=None): return self @check_transform_input - def _transform(self, x: List[Mol]) -> np.ndarray: - arr = np.zeros((len(x), len(self.desc_list))) - for i, mol in enumerate(x): - arr[i, :] = self._transform_mol(mol) - return arr + def _transform(self, x: List[Mol]) -> Union[np.ndarray, np.ma.MaskedArray]: + if self.safe_inference_mode: + arrays = [self._transform_mol(mol) for mol in x] + return np.ma.array(arrays) + else: + arr = np.zeros((len(x), len(self.desc_list))) + for i, mol in enumerate(x): + arr[i, :] = self._transform_mol(mol) + return arr - def transform(self, x: List[Mol], y=None) -> np.ndarray: + def transform(self, x: List[Mol], y=None) -> Union[np.ndarray, np.ma.MaskedArray]: """Transform a list of molecules into an array of descriptor values Parameters ---------- @@ -131,8 +138,8 @@ def transform(self, x: List[Mol], y=None) -> np.ndarray: Returns ------- - np.array - Descriptors, shape (samples, length of .selected_descriptors ) + Union[np.ndarray, np.ma.MaskedArray] + Descriptors, shape (samples, length of .selected_descriptors) """ if not self.parallel: @@ -148,11 +155,11 @@ def transform(self, x: List[Mol], y=None) -> np.ndarray: with get_context(self.start_method).Pool(processes=n_processes) as pool: params = self.get_params() x_chunks = np.array_split(x, n_chunks) - # x_chunks = [x.reshape(-1, 1) for x in x_chunks] - arrays = pool.map( - parallel_helper, [(params, x) for x in x_chunks] - ) # is the helper function a safer way of handling the picklind and child process communication - arr = np.concatenate(arrays) + arrays = pool.map(parallel_helper, [(params, x) for x in x_chunks]) + if self.safe_inference_mode: + arr = np.ma.concatenate(arrays) + else: + arr = np.concatenate(arrays) return arr diff --git a/scikit_mol/wrapper.py b/scikit_mol/wrapper.py index 6134778..27f5694 100644 --- a/scikit_mol/wrapper.py +++ b/scikit_mol/wrapper.py @@ -28,12 +28,19 @@ def wrapper(obj, X, y=None, *args, **kwargs): ) return func(obj, X, y, *args, **kwargs) + # Initialize valid_mask as all True + valid_mask = np.ones(X.shape[0], dtype=bool) + + # Handle masked arrays if isinstance(X, np.ma.MaskedArray): - mask_invalid = X.mask.any(axis=1) - finite_mask = np.isfinite(X.data).all(axis=1) - valid_mask = ~mask_invalid & finite_mask - else: - valid_mask = np.isfinite(X).all(axis=1) + valid_mask &= ~X.mask.any(axis=1) + + # Handle non-finite values if required + if getattr(obj, "mask_nonfinite", True): + if isinstance(X, np.ma.MaskedArray): + valid_mask &= np.isfinite(X.data).all(axis=1) + else: + valid_mask &= np.isfinite(X).all(axis=1) if warn_on_invalid and not np.all(valid_mask): warnings.warn( @@ -78,78 +85,80 @@ def wrapper(obj, X, y=None, *args, **kwargs): return decorator -class NanGuardWrapper(BaseEstimator, TransformerMixin): - """Nan/Inf safe wrapper for sklearn estimator objects.""" +# class NanGuardWrapper(BaseEstimator, TransformerMixin): +# """Nan/Inf safe wrapper for sklearn estimator objects.""" - def __init__( - self, - estimator: BaseEstimator, - handle_errors: bool = False, - replace_value=np.nan, - ): - super().__init__() - self.handle_errors = handle_errors - self.replace_value = replace_value - self.estimator = estimator +# def __init__( +# self, +# estimator: BaseEstimator, +# handle_errors: bool = False, +# replace_value=np.nan, +# mask_nonfinite: bool = True, +# ): +# super().__init__() +# self.handle_errors = handle_errors +# self.replace_value = replace_value +# self.estimator = estimator +# self.mask_nonfinite = mask_nonfinite - def has_predict(self) -> bool: - return hasattr(self.estimator, "predict") +# def has_predict(self) -> bool: +# return hasattr(self.estimator, "predict") - def has_predict_proba(self) -> bool: - return hasattr(self.estimator, "predict_proba") +# def has_predict_proba(self) -> bool: +# return hasattr(self.estimator, "predict_proba") - def has_transform(self) -> bool: - return hasattr(self.estimator, "transform") +# def has_transform(self) -> bool: +# return hasattr(self.estimator, "transform") - def has_fit_transform(self) -> bool: - return hasattr(self.estimator, "fit_transform") +# def has_fit_transform(self) -> bool: +# return hasattr(self.estimator, "fit_transform") - def has_score(self) -> bool: - return hasattr(self.estimator, "score") +# def has_score(self) -> bool: +# return hasattr(self.estimator, "score") - def has_n_features_in_(self) -> bool: - return hasattr(self.estimator, "n_features_in_") +# def has_n_features_in_(self) -> bool: +# return hasattr(self.estimator, "n_features_in_") - def has_decision_function(self) -> bool: - return hasattr(self.estimator, "decision_function") +# def has_decision_function(self) -> bool: +# return hasattr(self.estimator, "decision_function") - @property - def n_features_in_(self) -> int: - return self.estimator.n_features_in_ +# @property +# def n_features_in_(self) -> int: +# return self.estimator.n_features_in_ - @filter_invalid_rows(warn_on_invalid=True) - def fit(self, X, *args, **fit_params) -> Any: - return self.estimator.fit(X, *args, **fit_params) +# @filter_invalid_rows(warn_on_invalid=True) +# def fit(self, X, *args, **fit_params) -> Any: +# return self.estimator.fit(X, *args, **fit_params) - @available_if(has_predict) - @filter_invalid_rows() - def predict(self, X): - return self.estimator.predict(X) +# @available_if(has_predict) +# @filter_invalid_rows() +# def predict(self, X): +# return self.estimator.predict(X) - @available_if(has_decision_function) - @filter_invalid_rows() - def decision_function(self, X): - return self.estimator.decision_function(X) +# @available_if(has_decision_function) +# @filter_invalid_rows() +# def decision_function(self, X): +# return self.estimator.decision_function(X) - @available_if(has_predict_proba) - @filter_invalid_rows() - def predict_proba(self, X): - return self.estimator.predict_proba(X) +# @available_if(has_predict_proba) +# @filter_invalid_rows() +# def predict_proba(self, X): +# return self.estimator.predict_proba(X) - @available_if(has_transform) - @filter_invalid_rows() - def transform(self, X): - return self.estimator.transform(X) +# @available_if(has_transform) +# @filter_invalid_rows() +# def transform(self, X): +# return self.estimator.transform(X) - @available_if(has_fit_transform) - @filter_invalid_rows(warn_on_invalid=True) - def fit_transform(self, X, y): - return self.estimator.fit_transform(X, y) +# @available_if(has_fit_transform) +# @filter_invalid_rows(warn_on_invalid=True) +# def fit_transform(self, X, y): +# return self.estimator.fit_transform(X, y) - @available_if(has_score) - @filter_invalid_rows(warn_on_invalid=True) - def score(self, X, y): - return self.estimator.score(X, y) +# @available_if(has_score) +# @filter_invalid_rows(warn_on_invalid=True) +# def score(self, X, y): +# return self.estimator.score(X, y) class SafeInferenceWrapper(BaseEstimator, TransformerMixin): @@ -176,10 +185,12 @@ def __init__( estimator: BaseEstimator, safe_inference_mode: bool = True, replace_value=np.nan, + mask_nonfinite: bool = True, ): self.estimator = estimator self.safe_inference_mode = safe_inference_mode self.replace_value = replace_value + self.mask_nonfinite = mask_nonfinite @property def n_features_in_(self): diff --git a/tests/test_desctransformer.py b/tests/test_desctransformer.py index 959f9fc..6877def 100644 --- a/tests/test_desctransformer.py +++ b/tests/test_desctransformer.py @@ -1,14 +1,23 @@ import time -import pytest +import pytest import numpy as np import pandas as pd +import numpy.ma as ma from rdkit.Chem import Descriptors import sklearn from packaging.version import Version from scikit_mol.conversions import SmilesToMolTransformer from scikit_mol.descriptors import MolecularDescriptorTransformer from scikit_mol.core import SKLEARN_VERSION_PANDAS_OUT -from fixtures import mols_list, smiles_list, mols_container, smiles_container, skip_pandas_output_test +from fixtures import ( + mols_list, + smiles_list, + invalid_smiles_list, + mols_container, + smiles_container, + skip_pandas_output_test, + mols_with_invalid_container, +) from sklearn import clone from sklearn.pipeline import Pipeline import joblib @@ -18,79 +27,89 @@ def default_descriptor_transformer(): return MolecularDescriptorTransformer() + @pytest.fixture def selected_descriptor_transformer(): - return MolecularDescriptorTransformer(desc_list=['HeavyAtomCount', 'FractionCSP3', 'RingCount', 'MolLogP', 'MolWt']) + return MolecularDescriptorTransformer( + desc_list=["HeavyAtomCount", "FractionCSP3", "RingCount", "MolLogP", "MolWt"] + ) + -def test_descriptor_transformer_clonability( default_descriptor_transformer): - for t in [ default_descriptor_transformer]: - params = t.get_params() +def test_descriptor_transformer_clonability(default_descriptor_transformer): + for t in [default_descriptor_transformer]: + params = t.get_params() t2 = clone(t) params_2 = t2.get_params() - #Parameters of cloned transformers should be the same - assert all([ params[key] == params_2[key] for key in params.keys()]) - #Cloned transformers should not be the same object + # Parameters of cloned transformers should be the same + assert all([params[key] == params_2[key] for key in params.keys()]) + # Cloned transformers should not be the same object assert t2 != t + def test_descriptor_transformer_set_params(default_descriptor_transformer): for t in [default_descriptor_transformer]: - params = t.get_params() - #change extracted dictionary - params['desc_list'] = ['HeavyAtomCount', 'FractionCSP3'] - #change params in transformer - t.set_params(desc_list = ['HeavyAtomCount', 'FractionCSP3']) + params = t.get_params() + # change extracted dictionary + params["desc_list"] = ["HeavyAtomCount", "FractionCSP3"] + # change params in transformer + t.set_params(desc_list=["HeavyAtomCount", "FractionCSP3"]) # get parameters as dictionary and assert that it is the same params_2 = t.get_params() - assert all([ params[key] == params_2[key] for key in params.keys()]) + assert all([params[key] == params_2[key] for key in params.keys()]) assert len(default_descriptor_transformer.selected_descriptors) == 2 -def test_descriptor_transformer_available_descriptors(default_descriptor_transformer, selected_descriptor_transformer): - #Default have as many as in RDkit and all are selected - assert (len(default_descriptor_transformer.available_descriptors) == len(Descriptors._descList)) - assert (len(default_descriptor_transformer.selected_descriptors) == len(Descriptors._descList)) - #Default have as many as in RDkit but only 5 are selected - assert (len(selected_descriptor_transformer.available_descriptors) == len(Descriptors._descList)) - assert (len(selected_descriptor_transformer.selected_descriptors) == 5) - -def test_descriptor_transformer_transform(mols_container, default_descriptor_transformer): +def test_descriptor_transformer_available_descriptors( + default_descriptor_transformer, selected_descriptor_transformer +): + # Default have as many as in RDkit and all are selected + assert len(default_descriptor_transformer.available_descriptors) == len( + Descriptors._descList + ) + assert len(default_descriptor_transformer.selected_descriptors) == len( + Descriptors._descList + ) + # Default have as many as in RDkit but only 5 are selected + assert len(selected_descriptor_transformer.available_descriptors) == len( + Descriptors._descList + ) + assert len(selected_descriptor_transformer.selected_descriptors) == 5 + + +def test_descriptor_transformer_transform( + mols_container, default_descriptor_transformer +): features = default_descriptor_transformer.transform(mols_container) - assert(len(features) == len(mols_container)) - assert(len(features[0]) == len(Descriptors._descList)) - + assert len(features) == len(mols_container) + assert len(features[0]) == len(Descriptors._descList) + + def test_descriptor_transformer_wrong_descriptors(): with pytest.raises(AssertionError): - MolecularDescriptorTransformer(desc_list=['Color', 'Icecream content', 'ChokolateDarkness', 'Content42', 'MolWt']) - + MolecularDescriptorTransformer( + desc_list=[ + "Color", + "Icecream content", + "ChokolateDarkness", + "Content42", + "MolWt", + ] + ) def test_descriptor_transformer_parallel(mols_list, default_descriptor_transformer): default_descriptor_transformer.set_params(parallel=True) features = default_descriptor_transformer.transform(mols_list) - assert(len(features) == len(mols_list)) - assert(len(features[0]) == len(Descriptors._descList)) - #Now with Rdkit 2022.3 creating a second transformer and running it, froze the process - transformer2 = MolecularDescriptorTransformer(**default_descriptor_transformer.get_params()) + assert len(features) == len(mols_list) + assert len(features[0]) == len(Descriptors._descList) + # Now with Rdkit 2022.3 creating a second transformer and running it, froze the process + transformer2 = MolecularDescriptorTransformer( + **default_descriptor_transformer.get_params() + ) features2 = transformer2.transform(mols_list) - assert(len(features2) == len(mols_list)) - assert(len(features2[0]) == len(Descriptors._descList)) - - -@skip_pandas_output_test -def test_descriptor_transformer_pandas_output(mols_container, default_descriptor_transformer, selected_descriptor_transformer, pandas_output): - for transformer in [default_descriptor_transformer, selected_descriptor_transformer]: - features = transformer.transform(mols_container) - assert isinstance(features, pd.DataFrame) - assert features.shape[0] == len(mols_container) - assert features.columns.tolist() == transformer.selected_descriptors + assert len(features2) == len(mols_list) + assert len(features2[0]) == len(Descriptors._descList) -@skip_pandas_output_test -def test_descriptor_transformer_pandas_output_pipeline(smiles_container, default_descriptor_transformer, pandas_output): - pipeline = Pipeline([("s2m", SmilesToMolTransformer()), ("desc", default_descriptor_transformer)]) - features = pipeline.fit_transform(smiles_container) - assert isinstance(features, pd.DataFrame) - assert features.shape[0] == len(smiles_container) - assert features.columns.tolist() == default_descriptor_transformer.selected_descriptors # This test may fail on windows and mac (due to spawn rather than fork?) # def test_descriptor_transformer_parallel_speedup(mols_list, default_descriptor_transformer): @@ -100,7 +119,7 @@ def test_descriptor_transformer_pandas_output_pipeline(smiles_container, default # t0 = time.time() # features = default_descriptor_transformer.transform(mols_list) # t_single = time.time()-t0 - + # default_descriptor_transformer.set_params(parallel=True) # t0 = time.time() # features = default_descriptor_transformer.transform(mols_list) @@ -108,7 +127,94 @@ def test_descriptor_transformer_pandas_output_pipeline(smiles_container, default # assert(t_par < t_single/(n_phys_cpus/1.5)) # div by 1.5 as we don't assume full speedup - - +def test_transform_with_safe_inference_mode(mols_with_invalid_container): + transformer = MolecularDescriptorTransformer(safe_inference_mode=True) + descriptors = transformer.transform(mols_with_invalid_container) + + assert isinstance(descriptors, ma.MaskedArray) + assert len(descriptors) == len(mols_with_invalid_container) + + # Check that the last row (corresponding to the InvalidMol) is fully masked + assert np.all(descriptors.mask[-1]) + + # Check that other rows are not masked + assert not np.any(descriptors.mask[:-1]) + + +def test_transform_without_safe_inference_mode(mols_with_invalid_container): + transformer = MolecularDescriptorTransformer(safe_inference_mode=False) + with pytest.raises( + Exception + ): # You might want to be more specific about the exception type + transformer.transform(mols_with_invalid_container) + +def test_transform_parallel_with_safe_inference_mode(mols_with_invalid_container): + transformer = MolecularDescriptorTransformer( + safe_inference_mode=True, parallel=True + ) + descriptors = transformer.transform(mols_with_invalid_container) + + assert isinstance(descriptors, ma.MaskedArray) + assert len(descriptors) == len(mols_with_invalid_container) + + # Check that the last row (corresponding to the InvalidMol) is fully masked + assert np.all(descriptors.mask[-1]) + + # Check that other rows are not masked + assert not np.any(descriptors.mask[:-1]) + + +def test_transform_parallel_without_safe_inference_mode(mols_with_invalid_container): + transformer = MolecularDescriptorTransformer( + safe_inference_mode=False, parallel=True + ) + with pytest.raises( + Exception + ): # You might want to be more specific about the exception type + transformer.transform(mols_with_invalid_container) + + +def test_safe_inference_mode_setting(): + transformer = MolecularDescriptorTransformer() + assert not transformer.safe_inference_mode # Default should be False + + transformer.set_params(safe_inference_mode=True) + assert transformer.safe_inference_mode + + transformer.set_params(safe_inference_mode=False) + assert not transformer.safe_inference_mode + + +# TODO, if these tests are run before the others, these tests will fail, probably due to pandas output? +@skip_pandas_output_test +def test_descriptor_transformer_pandas_output( + mols_container, + default_descriptor_transformer, + selected_descriptor_transformer, + pandas_output, +): + for transformer in [ + default_descriptor_transformer, + selected_descriptor_transformer, + ]: + features = transformer.transform(mols_container) + assert isinstance(features, pd.DataFrame) + assert features.shape[0] == len(mols_container) + assert features.columns.tolist() == transformer.selected_descriptors + + +@skip_pandas_output_test +def test_descriptor_transformer_pandas_output_pipeline( + smiles_container, default_descriptor_transformer, pandas_output +): + pipeline = Pipeline( + [("s2m", SmilesToMolTransformer()), ("desc", default_descriptor_transformer)] + ) + features = pipeline.fit_transform(smiles_container) + assert isinstance(features, pd.DataFrame) + assert features.shape[0] == len(smiles_container) + assert ( + features.columns.tolist() == default_descriptor_transformer.selected_descriptors + ) From 8b85989e9031bbb6a3bd243849b1f80cc00ed4f6 Mon Sep 17 00:00:00 2001 From: Esben Jannik Bjerrum Date: Thu, 10 Oct 2024 20:29:13 +0200 Subject: [PATCH 33/41] Added dtype directly as properties on the objects, not on the class. --- scikit_mol/descriptors.py | 6 ++-- scikit_mol/fingerprints.py | 64 +++++++++++++++++++++++++------------- scikit_mol/standardizer.py | 7 +++-- scikit_mol/wrapper.py | 2 +- tests/test_smilestomol.py | 10 +++--- 5 files changed, 58 insertions(+), 31 deletions(-) diff --git a/scikit_mol/descriptors.py b/scikit_mol/descriptors.py index b852448..e82bfbd 100644 --- a/scikit_mol/descriptors.py +++ b/scikit_mol/descriptors.py @@ -43,11 +43,13 @@ def __init__( parallel: Union[bool, int] = False, start_method: str = None, # "fork", safe_inference_mode: bool = False, + dtype: np.dtype = np.float32, ): self.desc_list = desc_list self.parallel = parallel self.start_method = start_method self.safe_inference_mode = safe_inference_mode + self.dtype = dtype def _get_desc_calculator(self) -> MolecularDescriptorCalculator: if self.desc_list: @@ -120,9 +122,9 @@ def fit(self, x, y=None): def _transform(self, x: List[Mol]) -> Union[np.ndarray, np.ma.MaskedArray]: if self.safe_inference_mode: arrays = [self._transform_mol(mol) for mol in x] - return np.ma.array(arrays) + return np.ma.array(arrays, dtype=self.dtype) else: - arr = np.zeros((len(x), len(self.desc_list))) + arr = np.zeros((len(x), len(self.desc_list)), dtype=self.dtype) for i, mol in enumerate(x): arr[i, :] = self._transform_mol(mol) return arr diff --git a/scikit_mol/fingerprints.py b/scikit_mol/fingerprints.py index 86274f6..f044a06 100644 --- a/scikit_mol/fingerprints.py +++ b/scikit_mol/fingerprints.py @@ -33,14 +33,12 @@ def __init__( parallel: Union[bool, int] = False, start_method: str = None, safe_inference_mode: bool = False, + dtype: np.dtype = np.int8, ): self.parallel = parallel self.start_method = start_method self.safe_inference_mode = safe_inference_mode - - # The dtype of the fingerprint array computed by the transformer - # If needed this property can be overwritten in the child class. - _DTYPE_FINGERPRINT = np.int8 + self.dtype = dtype def _get_column_prefix(self) -> str: matched = _PATTERN_FINGERPRINT_TRANSFORMER.match(type(self).__name__) @@ -85,11 +83,11 @@ def _mol2fp(self, mol): def _fp2array(self, fp): if fp: - arr = np.zeros((self.nBits,), dtype=self._DTYPE_FINGERPRINT) + arr = np.zeros((self.nBits,), dtype=self.dtype) DataStructs.ConvertToNumpyArray(fp, arr) return arr else: - return np.ma.masked_all((self.nBits,), dtype=self._DTYPE_FINGERPRINT) + return np.ma.masked_all((self.nBits,), dtype=self.dtype) def _transform_mol(self, mol): if not mol and self.safe_inference_mode: @@ -118,13 +116,13 @@ def _transform(self, X): return np.ma.stack(arrays) else: # Use the original, faster method if we're not in safe inference mode - arr = np.zeros((len(X), self.nBits), dtype=self._DTYPE_FINGERPRINT) + arr = np.zeros((len(X), self.nBits), dtype=self.dtype) for i, mol in enumerate(X): arr[i, :] = self._transform_mol(mol) return arr def _transform_sparse(self, X): - arr = np.zeros((len(X), self.nBits), dtype=self._DTYPE_FINGERPRINT) + arr = np.zeros((len(X), self.nBits), dtype=self.dtype) for i, mol in enumerate(X): arr[i, :] = self._transform_mol(mol) @@ -180,12 +178,17 @@ def transform(self, X, y=None): class MACCSKeysFingerprintTransformer(FpsTransformer): def __init__( - self, parallel: Union[bool, int] = False, safe_inference_mode: bool = False + self, + parallel: Union[bool, int] = False, + safe_inference_mode: bool = False, + dtype: np.dtype = np.int8, ): """MACCS keys fingerprinter calculates the 167 fixed MACCS keys """ - super().__init__(parallel=parallel, safe_inference_mode=safe_inference_mode) + super().__init__( + parallel=parallel, safe_inference_mode=safe_inference_mode, dtype=dtype + ) self.nBits = 167 @property @@ -219,6 +222,7 @@ def __init__( atomInvariantsGenerator=None, parallel: Union[bool, int] = False, safe_inference_mode: bool = False, + dtype: np.dtype = np.int8, ): """Calculates the RDKit fingerprints @@ -245,7 +249,9 @@ def __init__( atomInvariantsGenerator : _type_, optional atom invariants to be used during fingerprint generation, by default None """ - super().__init__(parallel=parallel, safe_inference_mode=safe_inference_mode) + super().__init__( + parallel=parallel, safe_inference_mode=safe_inference_mode, dtype=dtype + ) self.minPath = minPath self.maxPath = maxPath self.useHs = useHs @@ -298,8 +304,11 @@ def __init__( useCounts: bool = False, parallel: Union[bool, int] = False, safe_inference_mode: bool = False, + dtype: np.dtype = np.int8, ): - super().__init__(parallel=parallel, safe_inference_mode=safe_inference_mode) + super().__init__( + parallel=parallel, safe_inference_mode=safe_inference_mode, dtype=dtype + ) self.minLength = minLength self.maxLength = maxLength self.fromAtoms = fromAtoms @@ -355,8 +364,11 @@ def __init__( useCounts: bool = False, parallel: Union[bool, int] = False, safe_inference_mode: bool = False, + dtype: np.dtype = np.int8, ): - super().__init__(parallel=parallel, safe_inference_mode=safe_inference_mode) + super().__init__( + parallel=parallel, safe_inference_mode=safe_inference_mode, dtype=dtype + ) self.targetSize = targetSize self.fromAtoms = fromAtoms self.ignoreAtoms = ignoreAtoms @@ -391,10 +403,6 @@ def _mol2fp(self, mol): class MHFingerprintTransformer(FpsTransformer): - # https://jcheminf.biomedcentral.com/articles/10.1186/s13321-018-0321-8 - - _DTYPE_FINGERPRINT = np.int32 - def __init__( self, radius: int = 3, @@ -406,9 +414,12 @@ def __init__( seed: int = 42, parallel: Union[bool, int] = False, safe_inference_mode: bool = False, + dtype: np.dtype = np.int32, ): """Transforms the RDKit mol into the MinHash fingerprint (MHFP) + https://jcheminf.biomedcentral.com/articles/10.1186/s13321-018-0321-8 + Args: radius (int, optional): The MHFP radius. Defaults to 3. rings (bool, optional): Whether or not to include rings in the shingling. Defaults to True. @@ -419,7 +430,9 @@ def __init__( this is effectively the length of the FP seed (int, optional): The value used to seed numpy.random. Defaults to 0. """ - super().__init__(parallel=parallel, safe_inference_mode=safe_inference_mode) + super().__init__( + parallel=parallel, safe_inference_mode=safe_inference_mode, dtype=dtype + ) self.radius = radius self.rings = rings self.isomeric = isomeric @@ -498,6 +511,7 @@ def __init__( seed: int = 0, parallel: Union[bool, int] = False, safe_inference_mode: bool = False, + dtype: np.dtype = np.int8, ): """Transforms the RDKit mol into the SMILES extended connectivity fingerprint (SECFP) @@ -511,7 +525,9 @@ def __init__( n_permutations (int, optional): The number of permutations used for hashing. Defaults to 0. seed (int, optional): The value used to seed numpy.random. Defaults to 0. """ - super().__init__(parallel=parallel, safe_inference_mode=safe_inference_mode) + super().__init__( + parallel=parallel, safe_inference_mode=safe_inference_mode, dtype=dtype + ) self.radius = radius self.rings = rings self.isomeric = isomeric @@ -590,6 +606,7 @@ def __init__( useCounts=False, parallel: Union[bool, int] = False, safe_inference_mode: bool = False, + dtype: np.dtype = np.int8, ): """Transform RDKit mols into Count or bit-based hashed MorganFingerprints @@ -608,7 +625,9 @@ def __init__( useCounts : bool, optional If toggled will create the count and not bit-based fingerprint, by default False """ - super().__init__(parallel=parallel, safe_inference_mode=safe_inference_mode) + super().__init__( + parallel=parallel, safe_inference_mode=safe_inference_mode, dtype=dtype + ) self.nBits = nBits self.radius = radius self.useChirality = useChirality @@ -648,6 +667,7 @@ def __init__( useCounts: bool = False, parallel: Union[bool, int] = False, safe_inference_mode: bool = False, + dtype: np.dtype = np.int8, ): """Transform RDKit mols into Count or bit-based Avalon Fingerprints @@ -664,7 +684,9 @@ def __init__( useCounts : bool, optional If toggled will create the count and not bit-based fingerprint, by default False """ - super().__init__(parallel=parallel, safe_inference_mode=safe_inference_mode) + super().__init__( + parallel=parallel, safe_inference_mode=safe_inference_mode, dtype=dtype + ) self.nBits = nBits self.isQuery = isQuery self.resetVect = resetVect diff --git a/scikit_mol/standardizer.py b/scikit_mol/standardizer.py index 29a7722..76a8c55 100644 --- a/scikit_mol/standardizer.py +++ b/scikit_mol/standardizer.py @@ -26,9 +26,12 @@ def fit(self, X, y=None): def _standardize_mol(self, mol): if not mol: if self.safe_inference_mode: - return InvalidMol(str(self), "Invalid input molecule") + if isinstance(mol, InvalidMol): + return mol + else: + return InvalidMol(str(self), f"Invalid input molecule: {mol}") else: - raise ValueError("Invalid input molecule") + raise ValueError(f"Invalid input molecule: {mol}") try: block = BlockLogs() # Block all RDkit logging diff --git a/scikit_mol/wrapper.py b/scikit_mol/wrapper.py index 27f5694..4209038 100644 --- a/scikit_mol/wrapper.py +++ b/scikit_mol/wrapper.py @@ -11,7 +11,7 @@ class MaskedArrayError(ValueError): - """Raised when a masked array is passed but handle_errors is False.""" + """Raised when a masked array is passed but safe_inference_mode is False.""" pass diff --git a/tests/test_smilestomol.py b/tests/test_smilestomol.py index 737080d..19bf288 100644 --- a/tests/test_smilestomol.py +++ b/tests/test_smilestomol.py @@ -84,7 +84,7 @@ def test_smilestomol_inverse_transform(smilestomol_transformer, smiles_container def test_smilestomol_inverse_transform_with_invalid( invalid_smiles_list, smilestomol_transformer ): - smilestomol_transformer.set_params(handle_errors=True) + smilestomol_transformer.set_params(safe_inference_mode=True) # Forward transform mols = smilestomol_transformer.transform(invalid_smiles_list) @@ -110,8 +110,8 @@ def test_smilestomol_get_feature_names_out(smilestomol_transformer): assert feature_names == [DEFAULT_MOL_COLUMN_NAME] -def test_smilestomol_handle_errors(invalid_smiles_list, smilestomol_transformer): - smilestomol_transformer.set_params(handle_errors=True) +def test_smilestomol_safe_inference(invalid_smiles_list, smilestomol_transformer): + smilestomol_transformer.set_params(safe_inference_mode=True) result = smilestomol_transformer.transform(invalid_smiles_list) assert len(result) == len(invalid_smiles_list) @@ -135,10 +135,10 @@ def test_smilestomol_handle_errors(invalid_smiles_list, smilestomol_transformer) not skip_pandas_output_test, reason="Pandas output not supported in this sklearn version", ) -def test_smilestomol_handle_errors_pandas_output( +def test_smilestomol_safe_inference_pandas_output( invalid_smiles_list, smilestomol_transformer, pandas_output ): - smilestomol_transformer.set_params(handle_errors=True) + smilestomol_transformer.set_params(safe_inference_mode=True) result = smilestomol_transformer.transform(invalid_smiles_list) assert len(result) == len(invalid_smiles_list) From d074d2cb55b13ede8e0a576470bda6be30a9c47b Mon Sep 17 00:00:00 2001 From: Esben Jannik Bjerrum Date: Thu, 10 Oct 2024 20:32:49 +0200 Subject: [PATCH 34/41] Removed deprecated test --- tests/test_invalid_handling.py | 51 ---------------------------------- 1 file changed, 51 deletions(-) delete mode 100644 tests/test_invalid_handling.py diff --git a/tests/test_invalid_handling.py b/tests/test_invalid_handling.py deleted file mode 100644 index dfa327b..0000000 --- a/tests/test_invalid_handling.py +++ /dev/null @@ -1,51 +0,0 @@ -import numpy as np -import pytest -from sklearn.decomposition import PCA -from sklearn.pipeline import Pipeline - -from fixtures import smiles_list, invalid_smiles_list -from scikit_mol.conversions import SmilesToMolTransformer -from scikit_mol.fingerprints import ( - MorganFingerprintTransformer, - MACCSKeysFingerprintTransformer, -) -from scikit_mol.wrapper import NanGuardWrapper # WrappedTransformer - -# from scikit_mol._invalid import NumpyArrayWithInvalidInstances -# from test_invalid_helpers.invalid_transformer import TestInvalidTransformer - - -@pytest.fixture -def smilestofp_pipeline(): - pipeline = Pipeline( - [ - ("smiles_to_mol", SmilesToMolTransformer(handle_errors=True)), - ("mol_2_fp", MACCSKeysFingerprintTransformer(handle_errors=True)), - ("PCA", NanGuardWrapper(PCA(2), handle_errors=True)), - ] - ) - return pipeline - - -def test_descriptor_transformer(smiles_list, invalid_smiles_list, smilestofp_pipeline): - # smilestofp_pipeline.set_params() - mol_pca = smilestofp_pipeline.fit_transform(smiles_list) - error_mol_pca = smilestofp_pipeline.fit_transform(invalid_smiles_list) - - print(mol_pca.shape) - assert mol_pca.shape == ( - len(smiles_list), - 2, - ), "The PCA does not return the proper dimensions." - - expected_nans = np.array([[0, 0, 0, 0, 0, 1], [0, 0, 0, 0, 0, 1]]).T - if not np.all(np.equal(expected_nans, np.isnan(error_mol_pca))): - raise ValueError("Errors were replaced on the wrong positions.") - - non_nan_rows = ~np.any(np.isnan(error_mol_pca), axis=1) - assert np.all( - np.isclose(mol_pca, error_mol_pca[non_nan_rows, :]) - ), "Removing errors introduces changes in the PCA output." - - # TODO, test with and without error handling on - # TODO, test with other transformers From 973dea157beb5d5c5ca73e4dfd5cbcb364f57840 Mon Sep 17 00:00:00 2001 From: Esben Jannik Bjerrum Date: Thu, 10 Oct 2024 21:45:35 +0200 Subject: [PATCH 35/41] Some minor fixes and updates of messages to be more concise --- scikit_mol/conversions.py | 8 +++++--- scikit_mol/descriptors.py | 2 +- scikit_mol/wrapper.py | 2 +- 3 files changed, 7 insertions(+), 5 deletions(-) diff --git a/scikit_mol/conversions.py b/scikit_mol/conversions.py index 0f5b11d..62520b7 100644 --- a/scikit_mol/conversions.py +++ b/scikit_mol/conversions.py @@ -94,18 +94,20 @@ def _transform(self, X): if mol: X_out.append(mol) else: - mol = Chem.MolFromSmiles(smiles, sanitize=False) + mol = Chem.MolFromSmiles( + smiles, sanitize=False + ) # TODO We could maybe convert mol, and then use Chem.SanitizeMol to get the error message from the sanitizer, and only parse once? if mol: errors = Chem.DetectChemistryProblems(mol) error_message = "\n".join(error.Message() for error in errors) - message = f"Invalid SMILES: {error_message}" + message = f"Invalid Molecule: {error_message}" else: message = f"Invalid SMILES: {smiles}" X_out.append(InvalidMol(str(self), message)) if not self.safe_inference_mode and not all(X_out): fails = [x for x in X_out if not x] raise ValueError( - f"Invalid SMILES found: {fails}." + f"Invalid input found: {fails}." ) # TODO with this approach we get all errors, but we do process ALL the smiles first which could be slow return np.array(X_out).reshape(-1, 1) diff --git a/scikit_mol/descriptors.py b/scikit_mol/descriptors.py index e82bfbd..905a098 100644 --- a/scikit_mol/descriptors.py +++ b/scikit_mol/descriptors.py @@ -105,7 +105,7 @@ def _transform_mol(self, mol: Mol) -> Union[np.ndarray, np.ma.MaskedArray]: if self.safe_inference_mode: return np.ma.masked_all(len(self.desc_list)) else: - raise ValueError("Invalid molecule provided: {mol}") + raise ValueError(f"Invalid molecule provided: {mol}") try: return np.array(list(self.calculators.CalcDescriptors(mol))) except Exception as e: diff --git a/scikit_mol/wrapper.py b/scikit_mol/wrapper.py index 4209038..7ee9494 100644 --- a/scikit_mol/wrapper.py +++ b/scikit_mol/wrapper.py @@ -183,7 +183,7 @@ class SafeInferenceWrapper(BaseEstimator, TransformerMixin): def __init__( self, estimator: BaseEstimator, - safe_inference_mode: bool = True, + safe_inference_mode: bool = False, replace_value=np.nan, mask_nonfinite: bool = True, ): From 48d423364d8c136beb92e0bd5390055e850f8b3a Mon Sep 17 00:00:00 2001 From: Esben Jannik Bjerrum Date: Fri, 11 Oct 2024 16:24:14 +0200 Subject: [PATCH 36/41] Fixed double parsing of SMILES --- scikit_mol/conversions.py | 24 +++++++++++++----------- 1 file changed, 13 insertions(+), 11 deletions(-) diff --git a/scikit_mol/conversions.py b/scikit_mol/conversions.py index 62520b7..da206f5 100644 --- a/scikit_mol/conversions.py +++ b/scikit_mol/conversions.py @@ -2,9 +2,11 @@ import multiprocessing from typing import Union from rdkit import Chem +from rdkit.rdBase import BlockLogs import numpy as np from sklearn.base import BaseEstimator, TransformerMixin +from torch import Block from scikit_mol.core import ( check_transform_input, @@ -89,21 +91,21 @@ def transform(self, X_smiles_list, y=None): @check_transform_input def _transform(self, X): X_out = [] - for smiles in X: - mol = Chem.MolFromSmiles(smiles) - if mol: - X_out.append(mol) - else: - mol = Chem.MolFromSmiles( - smiles, sanitize=False - ) # TODO We could maybe convert mol, and then use Chem.SanitizeMol to get the error message from the sanitizer, and only parse once? + with BlockLogs(): + for smiles in X: + mol = Chem.MolFromSmiles(smiles, sanitize=False) if mol: errors = Chem.DetectChemistryProblems(mol) - error_message = "\n".join(error.Message() for error in errors) - message = f"Invalid Molecule: {error_message}" + if errors: + error_message = "\n".join(error.Message() for error in errors) + message = f"Invalid Molecule: {error_message}" + X_out.append(InvalidMol(str(self), message)) + else: + Chem.SanitizeMol(mol) + X_out.append(mol) else: message = f"Invalid SMILES: {smiles}" - X_out.append(InvalidMol(str(self), message)) + X_out.append(InvalidMol(str(self), message)) if not self.safe_inference_mode and not all(X_out): fails = [x for x in X_out if not x] raise ValueError( From a6076d1155c2985a4ec7e865542ae8a4952e500e Mon Sep 17 00:00:00 2001 From: Esben Jannik Bjerrum Date: Fri, 11 Oct 2024 16:48:45 +0200 Subject: [PATCH 37/41] Fixed an issue with pandas output --- scikit_mol/wrapper.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/scikit_mol/wrapper.py b/scikit_mol/wrapper.py index 7ee9494..d4d45f9 100644 --- a/scikit_mol/wrapper.py +++ b/scikit_mol/wrapper.py @@ -7,6 +7,7 @@ from functools import wraps import warnings from sklearn.base import BaseEstimator, TransformerMixin +from sklearn.utils import check_array from sklearn.utils.metaestimators import available_if @@ -53,6 +54,7 @@ def wrapper(obj, X, y=None, *args, **kwargs): reduced_X = X[valid_mask] if y is not None: + y = check_array(y, force_all_finite=False) reduced_y = y[valid_mask] else: reduced_y = None @@ -229,3 +231,8 @@ def fit_transform(self, X, y=None, **fit_params): @filter_invalid_rows(warn_on_invalid=True) def score(self, X, y=None): return self.estimator.score(X, y) + + @available_if(lambda self: hasattr(self.estimator, "get_feature_names_out")) + @filter_invalid_rows(warn_on_invalid=True) + def get_feature_names_out(self, *args, **kwargs): + return self.estimator.get_feature_names_out(*args, **kwargs) From 08dc417cfd4ea408e36c582fcfea49cefe53335a Mon Sep 17 00:00:00 2001 From: Esben Jannik Bjerrum Date: Sun, 13 Oct 2024 16:57:23 +0200 Subject: [PATCH 38/41] Created a test and updated name of the module for safeinference --- scikit_mol/{wrapper.py => safeinference.py} | 12 +- tests/test_safeinferencemode.py | 115 ++++++++++++++++++++ 2 files changed, 126 insertions(+), 1 deletion(-) rename scikit_mol/{wrapper.py => safeinference.py} (95%) create mode 100644 tests/test_safeinferencemode.py diff --git a/scikit_mol/wrapper.py b/scikit_mol/safeinference.py similarity index 95% rename from scikit_mol/wrapper.py rename to scikit_mol/safeinference.py index d4d45f9..59b922c 100644 --- a/scikit_mol/wrapper.py +++ b/scikit_mol/safeinference.py @@ -10,6 +10,8 @@ from sklearn.utils import check_array from sklearn.utils.metaestimators import available_if +from .utilities import set_safe_inference_mode + class MaskedArrayError(ValueError): """Raised when a masked array is passed but safe_inference_mode is False.""" @@ -54,7 +56,15 @@ def wrapper(obj, X, y=None, *args, **kwargs): reduced_X = X[valid_mask] if y is not None: - y = check_array(y, force_all_finite=False) + # TODO, how can we check y in the same way as the estimator? + y = check_array( + y, + force_all_finite=False, # accept_sparse="csr", + ensure_2d=False, + dtype=None, + input_name="y", + estimator=obj, + ) reduced_y = y[valid_mask] else: reduced_y = None diff --git a/tests/test_safeinferencemode.py b/tests/test_safeinferencemode.py new file mode 100644 index 0000000..921cc0f --- /dev/null +++ b/tests/test_safeinferencemode.py @@ -0,0 +1,115 @@ +import pytest +import numpy as np +import pandas as pd +from sklearn.pipeline import Pipeline +from sklearn.ensemble import RandomForestRegressor +from scikit_mol.conversions import SmilesToMolTransformer +from scikit_mol.fingerprints import MorganFingerprintTransformer +from scikit_mol.safeinference import SafeInferenceWrapper +from scikit_mol.utilities import set_safe_inference_mode + +from fixtures import ( + SLC6A4_subset, + invalid_smiles_list, + skip_pandas_output_test, + smiles_list, +) + + +@pytest.fixture +def smiles_pipeline(): + return Pipeline( + [ + ("s2m", SmilesToMolTransformer()), + ("FP", MorganFingerprintTransformer()), + ( + "RF", + SafeInferenceWrapper( + RandomForestRegressor(n_estimators=3, random_state=42) + ), + ), + ] + ) + + +def test_safeinference_wrapper_basic(smiles_pipeline, SLC6A4_subset): + X_smiles, Y = SLC6A4_subset.SMILES, SLC6A4_subset.pXC50 + X_smiles = X_smiles.to_frame() + + # Set safe inference mode + set_safe_inference_mode(smiles_pipeline, True) + + # Train the model + smiles_pipeline.fit(X_smiles, Y) + + # Test prediction + predictions = smiles_pipeline.predict(X_smiles) + assert len(predictions) == len(X_smiles) + assert not np.any(np.isnan(predictions)) + + +def test_safeinference_wrapper_with_invalid_smiles( + smiles_pipeline, SLC6A4_subset, invalid_smiles_list +): + X_smiles, Y = SLC6A4_subset.SMILES[:100], SLC6A4_subset.pXC50[:100] + X_smiles = X_smiles.to_frame() + + # Set safe inference mode + set_safe_inference_mode(smiles_pipeline, True) + + # Train the model + smiles_pipeline.fit(X_smiles, Y) + + # Create a test set with invalid SMILES + X_test = pd.DataFrame({"SMILES": X_smiles["SMILES"].tolist() + invalid_smiles_list}) + + # Test prediction with invalid SMILES + predictions = smiles_pipeline.predict(X_test) + assert len(predictions) == len(X_test) + assert np.any(np.isnan(predictions)) + assert np.all(np.isnan(predictions[-1])) # Only last should be nan + assert np.all(~np.isnan(predictions[:-1])) # All others should not be nan + + +def test_safeinference_wrapper_without_safe_mode( + smiles_pipeline, SLC6A4_subset, invalid_smiles_list +): + X_smiles, Y = SLC6A4_subset.SMILES[:100], SLC6A4_subset.pXC50[:100] + X_smiles = X_smiles.to_frame() + + # Ensure safe inference mode is off (default behavior) + set_safe_inference_mode(smiles_pipeline, False) + + # Train the model + smiles_pipeline.fit(X_smiles, Y) + + # Create a test set with invalid SMILES + X_test = pd.DataFrame({"SMILES": X_smiles["SMILES"].tolist() + invalid_smiles_list}) + + # Test prediction with invalid SMILES + with pytest.raises(Exception): + smiles_pipeline.predict(X_test) + + +@skip_pandas_output_test +def test_safeinference_wrapper_pandas_output( + smiles_pipeline, SLC6A4_subset, pandas_output +): + X_smiles = SLC6A4_subset.SMILES[:100].to_frame() + + # Set safe inference mode + set_safe_inference_mode(smiles_pipeline, True) + + # Fit and transform (up to the FP step) + result = smiles_pipeline[:-1].fit_transform(X_smiles) + assert isinstance(result, pd.DataFrame) + assert result.shape[0] == len(X_smiles) + assert result.shape[1] == smiles_pipeline.named_steps["FP"].nBits + + +@skip_pandas_output_test +def test_safeinference_wrapper_get_feature_names_out(smiles_pipeline): + # Get feature names from the FP step + feature_names = smiles_pipeline.named_steps["FP"].get_feature_names_out() + assert len(feature_names) == smiles_pipeline.named_steps["FP"].nBits + assert all(isinstance(name, str) for name in feature_names) From a735b7cfbcb9f35625963a56e4199783448fd0eb Mon Sep 17 00:00:00 2001 From: Esben Jannik Bjerrum Date: Sun, 13 Oct 2024 17:08:05 +0200 Subject: [PATCH 39/41] Updated Notebook and links to notebook in readme (will first work when commiting to main on GitHub) --- README.md | 1 + notebooks/11_safe_inference.ipynb | 1023 +++++++++++++++++++++++++++++ notebooks/11_safe_inference.py | 145 ++++ notebooks/README.md | 1 + 4 files changed, 1170 insertions(+) create mode 100644 notebooks/11_safe_inference.ipynb create mode 100644 notebooks/11_safe_inference.py diff --git a/README.md b/README.md index 927e41c..453129a 100644 --- a/README.md +++ b/README.md @@ -82,6 +82,7 @@ There are a collection of notebooks in the notebooks directory which demonstrate - [Using skopt for hyperparameter tuning](https://github.com/EBjerrum/scikit-mol/tree/main/notebooks/08_external_library_skopt.ipynb) - [Testing different fingerprints as part of the hyperparameter optimization](https://github.com/EBjerrum/scikit-mol/blob/main/notebooks/09_Combinatorial_Method_Usage_with_FingerPrint_Transformers.ipynb) - [Using pandas output for easy feature importance analysis and combine pre-exisitng values with new computations](https://github.com/EBjerrum/scikit-mol/blob/main/notebooks/10_pipeline_pandas_output.ipynb) +- [Working with pipelines and estimators in safe inference mode for handling prediction on batches with invalid smiles or molecules](https://github.com/EBjerrum/scikit-mol/blob/main/notebooks/11_safe_inference.ipynb) We also put a software note on ChemRxiv. [https://doi.org/10.26434/chemrxiv-2023-fzqwd](https://doi.org/10.26434/chemrxiv-2023-fzqwd) diff --git a/notebooks/11_safe_inference.ipynb b/notebooks/11_safe_inference.ipynb new file mode 100644 index 0000000..6ee786e --- /dev/null +++ b/notebooks/11_safe_inference.ipynb @@ -0,0 +1,1023 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Safe inference mode\n", + "\n", + "I think everyone which have worked with SMILES and RDKit sooner or later come across a SMILES that doesn't parse. It can happen if the SMILES was produced with a different toolkit that are less strict with e.g. valence rules, or maybe a characher was missing in the copying from the email. During curation of the dataset for training models, these SMILES need to be identfied and eventually fixed or removed. But what happens when we are finished with our modelling? What kind of molecules and SMILES will a user of the model send for the model in the future when it's in deployment. What kind of SMILES will a generative model create that we need to predict? We don't know and we won't know. So it's kind of crucial to be able to handle these situations. Scikit-Learn models usually simply explodes the entire batch that are being predicted. This is where safe_inference_mode was introduced in Scikit-Mol. With the introduction all transformers got a safe inference mode, where they handle invalid input. How they handle it depends a bit on the transformer, so we will go through the different usual steps and see how things have changed with the introduction of the safe inference mode.\n", + "\n", + "NOTE! In the following demonstration I switch on the safe inference mode individually for demonstration purposes. I would not recommend to do that while building and training models, instead I would switch it on _after_ training and evaluation (more on that later). Otherwise there's a risk to train on the 2% of a dataset that didn't fail....\n", + "\n", + "First some imports and test SMILES and molecules." + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([[],\n", + " [],\n", + " [],\n", + " [],\n", + " [InvalidMol('SmilesToMolTransformer(safe_inference_mode=True)', error='Invalid Molecule: Explicit valence for atom # 0 N, 4, is greater than permitted')],\n", + " [InvalidMol('SmilesToMolTransformer(safe_inference_mode=True)', error='Invalid SMILES: I'm not a SMILES')]],\n", + " dtype=object)" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from rdkit import Chem\n", + "from scikit_mol.conversions import SmilesToMolTransformer\n", + "\n", + "#We have some deprecation warnings, we are adressing them, but they just distract from this demonstration\n", + "import warnings\n", + "warnings.filterwarnings(\"ignore\", category=DeprecationWarning) \n", + "\n", + "smiles = [\"C1=CC=C(C=C1)F\", \"C1=CC=C(C=C1)O\", \"C1=CC=C(C=C1)N\", \"C1=CC=C(C=C1)Cl\"]\n", + "smiles_with_invalid = smiles + [\"N(C)(C)(C)C\", \"I'm not a SMILES\"]\n", + "\n", + "smi2mol = SmilesToMolTransformer(safe_inference_mode=True)\n", + "\n", + "mols_with_invalid = smi2mol.transform(smiles_with_invalid)\n", + "mols_with_invalid" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Without the safe inference mode, the transformation would simply fail, but now we get the expected array back with our RDKit molecules and a last entry which is an object of the type InvalidMol. InvalidMol is simply a placeholder that tells what step failed the conversion and the error. InvalidMol evaluates to `False` in boolean contexts, so it gets easy to filter away and handle in `if`s and list comprehensions. As example:" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[array([], dtype=object),\n", + " array([], dtype=object),\n", + " array([], dtype=object),\n", + " array([], dtype=object)]" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "[mol for mol in mols_with_invalid if mol]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "or" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([,\n", + " ,\n", + " ,\n", + " ], dtype=object)" + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "mask = mols_with_invalid.astype(bool)\n", + "mols_with_invalid[mask]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Having a failsafe SmilesToMol conversion leads us to next step, featurization. The transformers in safe inference mode now return a NumPy masked array instead of a regular NumPy array. It simply evaluates the incoming mols in a boolean context, so e.g. `None`, `np.nan` and other Python objects that evaluates to False will also get masked (i.e. if you use a dataframe with an ROMol column produced with the PandasTools utility)" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[17:02:50] DEPRECATION WARNING: please use MorganGenerator\n", + "[17:02:50] DEPRECATION WARNING: please use MorganGenerator\n", + "[17:02:50] DEPRECATION WARNING: please use MorganGenerator\n", + "[17:02:50] DEPRECATION WARNING: please use MorganGenerator\n" + ] + }, + { + "data": { + "text/plain": [ + "masked_array(\n", + " data=[[0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 0, 1, 1, 1,\n", + " 0, 1, 1, 0],\n", + " [0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 1, 1, 1, 0, 0, 1, 0, 0, 1, 1,\n", + " 0, 0, 1, 0],\n", + " [0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 1, 1, 0, 1, 0, 1, 1, 1,\n", + " 0, 0, 0, 0],\n", + " [1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 1, 1,\n", + " 0, 1, 0, 1],\n", + " [--, --, --, --, --, --, --, --, --, --, --, --, --, --, --, --,\n", + " --, --, --, --, --, --, --, --, --],\n", + " [--, --, --, --, --, --, --, --, --, --, --, --, --, --, --, --,\n", + " --, --, --, --, --, --, --, --, --]],\n", + " mask=[[False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False],\n", + " [False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False],\n", + " [False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False],\n", + " [False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False],\n", + " [ True, True, True, True, True, True, True, True, True,\n", + " True, True, True, True, True, True, True, True, True,\n", + " True, True, True, True, True, True, True],\n", + " [ True, True, True, True, True, True, True, True, True,\n", + " True, True, True, True, True, True, True, True, True,\n", + " True, True, True, True, True, True, True]],\n", + " fill_value=999999,\n", + " dtype=int8)" + ] + }, + "execution_count": 15, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from scikit_mol.fingerprints import MorganFingerprintTransformer\n", + "\n", + "mfp = MorganFingerprintTransformer(radius=2, nBits=25, safe_inference_mode=True)\n", + "fps = mfp.transform(mols_with_invalid)\n", + "fps\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "However, currently scikit-learn models accepts masked arrays, but they do not respect the mask! So if you fed it directly to the model to train, it would seemingly work, but the invalid samples would all have the fill_value, meaning you could get weird results. Instead we need the last part of the puzzle, the SafeInferenceWrapper class." + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/esben/git/scikit-mol/scikit_mol/safeinference.py:49: UserWarning: SafeInferenceWrapper is in safe_inference_mode during use of fit and invalid data detected. This mode is intended for safe inference in production, not for training and evaluation.\n", + " warnings.warn(\n" + ] + }, + { + "data": { + "text/plain": [ + "array([ 0., 1., 0., 1., nan, nan])" + ] + }, + "execution_count": 16, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from scikit_mol.safeinference import SafeInferenceWrapper\n", + "from sklearn.linear_model import LogisticRegression\n", + "import numpy as np\n", + "\n", + "regressor = LogisticRegression()\n", + "wrapper = SafeInferenceWrapper(regressor, safe_inference_mode=True)\n", + "wrapper.fit(fps, [0,1,0,1,0,1])\n", + "wrapper.predict(fps)\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The prediction went fine both in fit and in prediction, where the result shows `nan` for the invalid entries. However, please note fit in sage_inference_mode is not recommended in a training session, but you are warned and not blocked, because maybe you know what you do and do it on purpose.\n", + "The SafeInferenceMapper both handles rows that are masked in masked arrays, but also checks rows for nonfinite values and filters these away. Sometimes some descriptors may return a inf or nan, even though the molecule itself is valid. The masking of nonfinite values can be switched off, maybe you are using a model that can handle missing data and only want to filter away invalid molecules.\n", + "\n", + "## Setting safe_inference_mode post-training\n", + "As I said before I believe in catching errors and fixing those during training, but what do we do when we need to switch on safe inference mode for all objects in a pipeline? There's of course a tool for that, so lets demo that:" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Without safe inference mode:\n", + "Prediction failed with exception: Invalid input found: [InvalidMol('SmilesToMolTransformer()', error='Invalid Molecule: Explicit valence for atom # 0 N, 4, is greater than permitted'), InvalidMol('SmilesToMolTransformer()', error='Invalid SMILES: I'm not a SMILES')].\n", + "\n", + "With safe inference mode:\n", + "[ 1. 0. 1. 0. nan nan]\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[17:02:50] DEPRECATION WARNING: please use MorganGenerator\n", + "[17:02:50] DEPRECATION WARNING: please use MorganGenerator\n", + "[17:02:50] DEPRECATION WARNING: please use MorganGenerator\n", + "[17:02:50] DEPRECATION WARNING: please use MorganGenerator\n", + "[17:02:50] DEPRECATION WARNING: please use MorganGenerator\n", + "[17:02:50] DEPRECATION WARNING: please use MorganGenerator\n", + "[17:02:50] DEPRECATION WARNING: please use MorganGenerator\n", + "[17:02:50] DEPRECATION WARNING: please use MorganGenerator\n" + ] + } + ], + "source": [ + "from scikit_mol.safeinference import set_safe_inference_mode\n", + "from sklearn.pipeline import Pipeline\n", + "\n", + "pipe = Pipeline([\n", + " (\"smi2mol\", SmilesToMolTransformer()),\n", + " (\"mfp\", MorganFingerprintTransformer(radius=2, nBits=25)),\n", + " (\"safe_regressor\", SafeInferenceWrapper(LogisticRegression()))\n", + "])\n", + "\n", + "pipe.fit(smiles, [1,0,1,0])\n", + "\n", + "print(\"Without safe inference mode:\")\n", + "try:\n", + " pipe.predict(smiles_with_invalid)\n", + "except Exception as e:\n", + " print(\"Prediction failed with exception: \", e)\n", + "print()\n", + "\n", + "set_safe_inference_mode(pipe, True)\n", + "\n", + "print(\"With safe inference mode:\")\n", + "print(pipe.predict(smiles_with_invalid))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We see that the prediction fail without safe inference mode, and proceeds when it's conveniently set by the `set_safe_inference_mode` utility. The model is now ready for save and reuse in a more failsafe manner :-)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Combining safe_inference_mode with pandas output\n", + "One potential issue can happen when we combine the safe_inference_mode with Pandas output mode of the transformers. It will work, but depending on the batch something surprising can happen due to the way that Pandas converts masked Numpy arrays. Let me demonstrate the issue, first we predict a batch without any errors." + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[17:02:50] DEPRECATION WARNING: please use MorganGenerator\n", + "[17:02:50] DEPRECATION WARNING: please use MorganGenerator\n", + "[17:02:50] DEPRECATION WARNING: please use MorganGenerator\n", + "[17:02:50] DEPRECATION WARNING: please use MorganGenerator\n" + ] + }, + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
fp_morgan_1fp_morgan_2fp_morgan_3fp_morgan_4fp_morgan_5fp_morgan_6fp_morgan_7fp_morgan_8fp_morgan_9fp_morgan_10...fp_morgan_16fp_morgan_17fp_morgan_18fp_morgan_19fp_morgan_20fp_morgan_21fp_morgan_22fp_morgan_23fp_morgan_24fp_morgan_25
00000000011...0101110110
10000000111...0100110010
20000000011...0101110000
31000000011...0100110101
\n", + "

4 rows × 25 columns

\n", + "
" + ], + "text/plain": [ + " fp_morgan_1 fp_morgan_2 fp_morgan_3 fp_morgan_4 fp_morgan_5 \\\n", + "0 0 0 0 0 0 \n", + "1 0 0 0 0 0 \n", + "2 0 0 0 0 0 \n", + "3 1 0 0 0 0 \n", + "\n", + " fp_morgan_6 fp_morgan_7 fp_morgan_8 fp_morgan_9 fp_morgan_10 ... \\\n", + "0 0 0 0 1 1 ... \n", + "1 0 0 1 1 1 ... \n", + "2 0 0 0 1 1 ... \n", + "3 0 0 0 1 1 ... \n", + "\n", + " fp_morgan_16 fp_morgan_17 fp_morgan_18 fp_morgan_19 fp_morgan_20 \\\n", + "0 0 1 0 1 1 \n", + "1 0 1 0 0 1 \n", + "2 0 1 0 1 1 \n", + "3 0 1 0 0 1 \n", + "\n", + " fp_morgan_21 fp_morgan_22 fp_morgan_23 fp_morgan_24 fp_morgan_25 \n", + "0 1 0 1 1 0 \n", + "1 1 0 0 1 0 \n", + "2 1 0 0 0 0 \n", + "3 1 0 1 0 1 \n", + "\n", + "[4 rows x 25 columns]" + ] + }, + "execution_count": 18, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "mfp.set_output(transform=\"pandas\")\n", + "\n", + "mols = smi2mol.transform(smiles)\n", + "\n", + "fps = mfp.transform(mols)\n", + "fps" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Then lets see if we transform a batch with an invalid molecule:" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[17:02:50] DEPRECATION WARNING: please use MorganGenerator\n", + "[17:02:50] DEPRECATION WARNING: please use MorganGenerator\n", + "[17:02:50] DEPRECATION WARNING: please use MorganGenerator\n", + "[17:02:50] DEPRECATION WARNING: please use MorganGenerator\n" + ] + }, + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
fp_morgan_1fp_morgan_2fp_morgan_3fp_morgan_4fp_morgan_5fp_morgan_6fp_morgan_7fp_morgan_8fp_morgan_9fp_morgan_10...fp_morgan_16fp_morgan_17fp_morgan_18fp_morgan_19fp_morgan_20fp_morgan_21fp_morgan_22fp_morgan_23fp_morgan_24fp_morgan_25
00.00.00.00.00.00.00.00.01.01.0...0.01.00.01.01.01.00.01.01.00.0
10.00.00.00.00.00.00.01.01.01.0...0.01.00.00.01.01.00.00.01.00.0
20.00.00.00.00.00.00.00.01.01.0...0.01.00.01.01.01.00.00.00.00.0
31.00.00.00.00.00.00.00.01.01.0...0.01.00.00.01.01.00.01.00.01.0
4NaNNaNNaNNaNNaNNaNNaNNaNNaNNaN...NaNNaNNaNNaNNaNNaNNaNNaNNaNNaN
5NaNNaNNaNNaNNaNNaNNaNNaNNaNNaN...NaNNaNNaNNaNNaNNaNNaNNaNNaNNaN
\n", + "

6 rows × 25 columns

\n", + "
" + ], + "text/plain": [ + " fp_morgan_1 fp_morgan_2 fp_morgan_3 fp_morgan_4 fp_morgan_5 \\\n", + "0 0.0 0.0 0.0 0.0 0.0 \n", + "1 0.0 0.0 0.0 0.0 0.0 \n", + "2 0.0 0.0 0.0 0.0 0.0 \n", + "3 1.0 0.0 0.0 0.0 0.0 \n", + "4 NaN NaN NaN NaN NaN \n", + "5 NaN NaN NaN NaN NaN \n", + "\n", + " fp_morgan_6 fp_morgan_7 fp_morgan_8 fp_morgan_9 fp_morgan_10 ... \\\n", + "0 0.0 0.0 0.0 1.0 1.0 ... \n", + "1 0.0 0.0 1.0 1.0 1.0 ... \n", + "2 0.0 0.0 0.0 1.0 1.0 ... \n", + "3 0.0 0.0 0.0 1.0 1.0 ... \n", + "4 NaN NaN NaN NaN NaN ... \n", + "5 NaN NaN NaN NaN NaN ... \n", + "\n", + " fp_morgan_16 fp_morgan_17 fp_morgan_18 fp_morgan_19 fp_morgan_20 \\\n", + "0 0.0 1.0 0.0 1.0 1.0 \n", + "1 0.0 1.0 0.0 0.0 1.0 \n", + "2 0.0 1.0 0.0 1.0 1.0 \n", + "3 0.0 1.0 0.0 0.0 1.0 \n", + "4 NaN NaN NaN NaN NaN \n", + "5 NaN NaN NaN NaN NaN \n", + "\n", + " fp_morgan_21 fp_morgan_22 fp_morgan_23 fp_morgan_24 fp_morgan_25 \n", + "0 1.0 0.0 1.0 1.0 0.0 \n", + "1 1.0 0.0 0.0 1.0 0.0 \n", + "2 1.0 0.0 0.0 0.0 0.0 \n", + "3 1.0 0.0 1.0 0.0 1.0 \n", + "4 NaN NaN NaN NaN NaN \n", + "5 NaN NaN NaN NaN NaN \n", + "\n", + "[6 rows x 25 columns]" + ] + }, + "execution_count": 19, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "fps = mfp.transform(mols_with_invalid)\n", + "fps" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The second output is no longer integers, but floats. As most sklearn models cast input arrays to float32 internally, this difference is likely benign, but that's not guaranteed! Thus if you want to use pandas output for your production models, do check that the final outputs are the same for the valid rows, with and without a single invalid row. Alternatively the dtype for the output of the transformer can be switched to float for consistency." + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[17:02:50] DEPRECATION WARNING: please use MorganGenerator\n", + "[17:02:50] DEPRECATION WARNING: please use MorganGenerator\n", + "[17:02:50] DEPRECATION WARNING: please use MorganGenerator\n", + "[17:02:50] DEPRECATION WARNING: please use MorganGenerator\n" + ] + }, + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
fp_morgan_1fp_morgan_2fp_morgan_3fp_morgan_4fp_morgan_5fp_morgan_6fp_morgan_7fp_morgan_8fp_morgan_9fp_morgan_10...fp_morgan_16fp_morgan_17fp_morgan_18fp_morgan_19fp_morgan_20fp_morgan_21fp_morgan_22fp_morgan_23fp_morgan_24fp_morgan_25
00.00.00.00.00.00.00.00.01.01.0...0.01.00.01.01.01.00.01.01.00.0
10.00.00.00.00.00.00.01.01.01.0...0.01.00.00.01.01.00.00.01.00.0
20.00.00.00.00.00.00.00.01.01.0...0.01.00.01.01.01.00.00.00.00.0
31.00.00.00.00.00.00.00.01.01.0...0.01.00.00.01.01.00.01.00.01.0
\n", + "

4 rows × 25 columns

\n", + "
" + ], + "text/plain": [ + " fp_morgan_1 fp_morgan_2 fp_morgan_3 fp_morgan_4 fp_morgan_5 \\\n", + "0 0.0 0.0 0.0 0.0 0.0 \n", + "1 0.0 0.0 0.0 0.0 0.0 \n", + "2 0.0 0.0 0.0 0.0 0.0 \n", + "3 1.0 0.0 0.0 0.0 0.0 \n", + "\n", + " fp_morgan_6 fp_morgan_7 fp_morgan_8 fp_morgan_9 fp_morgan_10 ... \\\n", + "0 0.0 0.0 0.0 1.0 1.0 ... \n", + "1 0.0 0.0 1.0 1.0 1.0 ... \n", + "2 0.0 0.0 0.0 1.0 1.0 ... \n", + "3 0.0 0.0 0.0 1.0 1.0 ... \n", + "\n", + " fp_morgan_16 fp_morgan_17 fp_morgan_18 fp_morgan_19 fp_morgan_20 \\\n", + "0 0.0 1.0 0.0 1.0 1.0 \n", + "1 0.0 1.0 0.0 0.0 1.0 \n", + "2 0.0 1.0 0.0 1.0 1.0 \n", + "3 0.0 1.0 0.0 0.0 1.0 \n", + "\n", + " fp_morgan_21 fp_morgan_22 fp_morgan_23 fp_morgan_24 fp_morgan_25 \n", + "0 1.0 0.0 1.0 1.0 0.0 \n", + "1 1.0 0.0 0.0 1.0 0.0 \n", + "2 1.0 0.0 0.0 0.0 0.0 \n", + "3 1.0 0.0 1.0 0.0 1.0 \n", + "\n", + "[4 rows x 25 columns]" + ] + }, + "execution_count": 20, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "mfp_float = MorganFingerprintTransformer(radius=2, nBits=25, safe_inference_mode=True, dtype=np.float32)\n", + "mfp_float.set_output(transform=\"pandas\")\n", + "fps = mfp_float.transform(mols)\n", + "fps" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "I hope this new feature of Scikit-Mol will make it even easier to handle models, even when used in environments without SMILES or molecule validity guarantees." + ] + } + ], + "metadata": { + "jupytext": { + "formats": "ipynb,py:percent" + }, + "kernelspec": { + "display_name": "vscode", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/notebooks/11_safe_inference.py b/notebooks/11_safe_inference.py new file mode 100644 index 0000000..83d4d99 --- /dev/null +++ b/notebooks/11_safe_inference.py @@ -0,0 +1,145 @@ +# --- +# jupyter: +# jupytext: +# formats: ipynb,py:percent +# text_representation: +# extension: .py +# format_name: percent +# format_version: '1.3' +# jupytext_version: 1.16.1 +# kernelspec: +# display_name: vscode +# language: python +# name: python3 +# --- + +# %% [markdown] +# # Safe inference mode +# +# I think everyone which have worked with SMILES and RDKit sooner or later come across a SMILES that doesn't parse. It can happen if the SMILES was produced with a different toolkit that are less strict with e.g. valence rules, or maybe a characher was missing in the copying from the email. During curation of the dataset for training models, these SMILES need to be identfied and eventually fixed or removed. But what happens when we are finished with our modelling? What kind of molecules and SMILES will a user of the model send for the model in the future when it's in deployment. What kind of SMILES will a generative model create that we need to predict? We don't know and we won't know. So it's kind of crucial to be able to handle these situations. Scikit-Learn models usually simply explodes the entire batch that are being predicted. This is where safe_inference_mode was introduced in Scikit-Mol. With the introduction all transformers got a safe inference mode, where they handle invalid input. How they handle it depends a bit on the transformer, so we will go through the different usual steps and see how things have changed with the introduction of the safe inference mode. +# +# NOTE! In the following demonstration I switch on the safe inference mode individually for demonstration purposes. I would not recommend to do that while building and training models, instead I would switch it on _after_ training and evaluation (more on that later). Otherwise there's a risk to train on the 2% of a dataset that didn't fail.... +# +# First some imports and test SMILES and molecules. + +# %% +from rdkit import Chem +from scikit_mol.conversions import SmilesToMolTransformer + +#We have some deprecation warnings, we are adressing them, but they just distract from this demonstration +import warnings +warnings.filterwarnings("ignore", category=DeprecationWarning) + +smiles = ["C1=CC=C(C=C1)F", "C1=CC=C(C=C1)O", "C1=CC=C(C=C1)N", "C1=CC=C(C=C1)Cl"] +smiles_with_invalid = smiles + ["N(C)(C)(C)C", "I'm not a SMILES"] + +smi2mol = SmilesToMolTransformer(safe_inference_mode=True) + +mols_with_invalid = smi2mol.transform(smiles_with_invalid) +mols_with_invalid + +# %% [markdown] +# Without the safe inference mode, the transformation would simply fail, but now we get the expected array back with our RDKit molecules and a last entry which is an object of the type InvalidMol. InvalidMol is simply a placeholder that tells what step failed the conversion and the error. InvalidMol evaluates to `False` in boolean contexts, so it gets easy to filter away and handle in `if`s and list comprehensions. As example: + +# %% +[mol for mol in mols_with_invalid if mol] + +# %% [markdown] +# or + +# %% +mask = mols_with_invalid.astype(bool) +mols_with_invalid[mask] + +# %% [markdown] +# Having a failsafe SmilesToMol conversion leads us to next step, featurization. The transformers in safe inference mode now return a NumPy masked array instead of a regular NumPy array. It simply evaluates the incoming mols in a boolean context, so e.g. `None`, `np.nan` and other Python objects that evaluates to False will also get masked (i.e. if you use a dataframe with an ROMol column produced with the PandasTools utility) + +# %% +from scikit_mol.fingerprints import MorganFingerprintTransformer + +mfp = MorganFingerprintTransformer(radius=2, nBits=25, safe_inference_mode=True) +fps = mfp.transform(mols_with_invalid) +fps + + +# %% [markdown] +# However, currently scikit-learn models accepts masked arrays, but they do not respect the mask! So if you fed it directly to the model to train, it would seemingly work, but the invalid samples would all have the fill_value, meaning you could get weird results. Instead we need the last part of the puzzle, the SafeInferenceWrapper class. + +# %% +from scikit_mol.safeinference import SafeInferenceWrapper +from sklearn.linear_model import LogisticRegression +import numpy as np + +regressor = LogisticRegression() +wrapper = SafeInferenceWrapper(regressor, safe_inference_mode=True) +wrapper.fit(fps, [0,1,0,1,0,1]) +wrapper.predict(fps) + + +# %% [markdown] +# + +# %% [markdown] +# The prediction went fine both in fit and in prediction, where the result shows `nan` for the invalid entries. However, please note fit in sage_inference_mode is not recommended in a training session, but you are warned and not blocked, because maybe you know what you do and do it on purpose. +# The SafeInferenceMapper both handles rows that are masked in masked arrays, but also checks rows for nonfinite values and filters these away. Sometimes some descriptors may return a inf or nan, even though the molecule itself is valid. The masking of nonfinite values can be switched off, maybe you are using a model that can handle missing data and only want to filter away invalid molecules. +# +# ## Setting safe_inference_mode post-training +# As I said before I believe in catching errors and fixing those during training, but what do we do when we need to switch on safe inference mode for all objects in a pipeline? There's of course a tool for that, so lets demo that: + +# %% +from scikit_mol.safeinference import set_safe_inference_mode +from sklearn.pipeline import Pipeline + +pipe = Pipeline([ + ("smi2mol", SmilesToMolTransformer()), + ("mfp", MorganFingerprintTransformer(radius=2, nBits=25)), + ("safe_regressor", SafeInferenceWrapper(LogisticRegression())) +]) + +pipe.fit(smiles, [1,0,1,0]) + +print("Without safe inference mode:") +try: + pipe.predict(smiles_with_invalid) +except Exception as e: + print("Prediction failed with exception: ", e) +print() + +set_safe_inference_mode(pipe, True) + +print("With safe inference mode:") +print(pipe.predict(smiles_with_invalid)) + +# %% [markdown] +# We see that the prediction fail without safe inference mode, and proceeds when it's conveniently set by the `set_safe_inference_mode` utility. The model is now ready for save and reuse in a more failsafe manner :-) + +# %% [markdown] +# ## Combining safe_inference_mode with pandas output +# One potential issue can happen when we combine the safe_inference_mode with Pandas output mode of the transformers. It will work, but depending on the batch something surprising can happen due to the way that Pandas converts masked Numpy arrays. Let me demonstrate the issue, first we predict a batch without any errors. + +# %% +mfp.set_output(transform="pandas") + +mols = smi2mol.transform(smiles) + +fps = mfp.transform(mols) +fps + +# %% [markdown] +# Then lets see if we transform a batch with an invalid molecule: + +# %% +fps = mfp.transform(mols_with_invalid) +fps + +# %% [markdown] +# The second output is no longer integers, but floats. As most sklearn models cast input arrays to float32 internally, this difference is likely benign, but that's not guaranteed! Thus if you want to use pandas output for your production models, do check that the final outputs are the same for the valid rows, with and without a single invalid row. Alternatively the dtype for the output of the transformer can be switched to float for consistency. + +# %% +mfp_float = MorganFingerprintTransformer(radius=2, nBits=25, safe_inference_mode=True, dtype=np.float32) +mfp_float.set_output(transform="pandas") +fps = mfp_float.transform(mols) +fps + +# %% [markdown] +# I hope this new feature of Scikit-Mol will make it even easier to handle models, even when used in environments without SMILES or molecule validity guarantees. diff --git a/notebooks/README.md b/notebooks/README.md index 8b0ec12..b744709 100644 --- a/notebooks/README.md +++ b/notebooks/README.md @@ -14,3 +14,4 @@ This is a collection of notebooks in the notebooks directory which demonstrates - [Using skopt for hyperparameter tuning](https://github.com/EBjerrum/scikit-mol/tree/main/notebooks/08_external_library_skopt.ipynb) - [Testing different fingerprints as part of the hyperparameter optimization](https://github.com/EBjerrum/scikit-mol/blob/main/notebooks/09_Combinatorial_Method_Usage_with_FingerPrint_Transformers.ipynb) - [Using pandas output for easy feature importance analysis and combine pre-exisitng values with new computations](https://github.com/EBjerrum/scikit-mol/blob/main/notebooks/10_pipeline_pandas_output.ipynb) +- [Working with pipelines and estimators in safe inference mode](https://github.com/EBjerrum/scikit-mol/blob/main/notebooks/11_safe_inference.ipynb) From f17fd43f799e287a693b66fd95d2e6762cbf20df Mon Sep 17 00:00:00 2001 From: Esben Jannik Bjerrum Date: Sun, 13 Oct 2024 17:11:59 +0200 Subject: [PATCH 40/41] Some cleanup --- scikit_mol/safeinference.py | 76 ------------------------------------- 1 file changed, 76 deletions(-) diff --git a/scikit_mol/safeinference.py b/scikit_mol/safeinference.py index 59b922c..401af4a 100644 --- a/scikit_mol/safeinference.py +++ b/scikit_mol/safeinference.py @@ -97,82 +97,6 @@ def wrapper(obj, X, y=None, *args, **kwargs): return decorator -# class NanGuardWrapper(BaseEstimator, TransformerMixin): -# """Nan/Inf safe wrapper for sklearn estimator objects.""" - -# def __init__( -# self, -# estimator: BaseEstimator, -# handle_errors: bool = False, -# replace_value=np.nan, -# mask_nonfinite: bool = True, -# ): -# super().__init__() -# self.handle_errors = handle_errors -# self.replace_value = replace_value -# self.estimator = estimator -# self.mask_nonfinite = mask_nonfinite - -# def has_predict(self) -> bool: -# return hasattr(self.estimator, "predict") - -# def has_predict_proba(self) -> bool: -# return hasattr(self.estimator, "predict_proba") - -# def has_transform(self) -> bool: -# return hasattr(self.estimator, "transform") - -# def has_fit_transform(self) -> bool: -# return hasattr(self.estimator, "fit_transform") - -# def has_score(self) -> bool: -# return hasattr(self.estimator, "score") - -# def has_n_features_in_(self) -> bool: -# return hasattr(self.estimator, "n_features_in_") - -# def has_decision_function(self) -> bool: -# return hasattr(self.estimator, "decision_function") - -# @property -# def n_features_in_(self) -> int: -# return self.estimator.n_features_in_ - -# @filter_invalid_rows(warn_on_invalid=True) -# def fit(self, X, *args, **fit_params) -> Any: -# return self.estimator.fit(X, *args, **fit_params) - -# @available_if(has_predict) -# @filter_invalid_rows() -# def predict(self, X): -# return self.estimator.predict(X) - -# @available_if(has_decision_function) -# @filter_invalid_rows() -# def decision_function(self, X): -# return self.estimator.decision_function(X) - -# @available_if(has_predict_proba) -# @filter_invalid_rows() -# def predict_proba(self, X): -# return self.estimator.predict_proba(X) - -# @available_if(has_transform) -# @filter_invalid_rows() -# def transform(self, X): -# return self.estimator.transform(X) - -# @available_if(has_fit_transform) -# @filter_invalid_rows(warn_on_invalid=True) -# def fit_transform(self, X, y): -# return self.estimator.fit_transform(X, y) - -# @available_if(has_score) -# @filter_invalid_rows(warn_on_invalid=True) -# def score(self, X, y): -# return self.estimator.score(X, y) - - class SafeInferenceWrapper(BaseEstimator, TransformerMixin): """ Wrapper for sklearn estimators to ensure safe inference in production environments. From c4e4c696d82cdf47e43d033160ff890aff74d906 Mon Sep 17 00:00:00 2001 From: Esben Jannik Bjerrum Date: Sun, 13 Oct 2024 17:22:13 +0200 Subject: [PATCH 41/41] Fixed spurious import --- scikit_mol/conversions.py | 1 - 1 file changed, 1 deletion(-) diff --git a/scikit_mol/conversions.py b/scikit_mol/conversions.py index da206f5..450ab31 100644 --- a/scikit_mol/conversions.py +++ b/scikit_mol/conversions.py @@ -6,7 +6,6 @@ import numpy as np from sklearn.base import BaseEstimator, TransformerMixin -from torch import Block from scikit_mol.core import ( check_transform_input,