From e8744d280eb08d7b91f6dc56336e6fc3eb1e5739 Mon Sep 17 00:00:00 2001 From: "Christian W. Feldmann" Date: Fri, 22 Sep 2023 11:06:23 +0200 Subject: [PATCH 01/13] Add InvalidInstance and add empty file for wrappers --- scikit_mol/utilities.py | 7 +++++++ scikit_mol/wrapper.py | 0 2 files changed, 7 insertions(+) create mode 100644 scikit_mol/wrapper.py diff --git a/scikit_mol/utilities.py b/scikit_mol/utilities.py index 70eac51..24356d4 100644 --- a/scikit_mol/utilities.py +++ b/scikit_mol/utilities.py @@ -1,7 +1,14 @@ #For a non-scikit-learn check smiles sanitizer class +from typing import NamedTuple import pandas as pd from rdkit import Chem + +class InvalidInstance(NamedTuple): + pipeline_step: str + error: str + + class CheckSmilesSanitazion: def __init__(self, return_mol=False): self.return_mol = return_mol diff --git a/scikit_mol/wrapper.py b/scikit_mol/wrapper.py new file mode 100644 index 0000000..e69de29 From 361284bde859a73eed7ab6548922a79e84834d04 Mon Sep 17 00:00:00 2001 From: "Christian W. Feldmann" Date: Fri, 22 Sep 2023 16:15:45 +0200 Subject: [PATCH 02/13] Nothing works but I want to save commit --- scikit_mol/_invalid.py | 118 +++++++++++++++++++++++++++++++++ scikit_mol/conversions.py | 5 +- scikit_mol/fingerprints.py | 15 +++-- scikit_mol/utilities.py | 7 +- scikit_mol/wrapper.py | 7 ++ tests/fixtures.py | 2 +- tests/test_invalid_handling.py | 25 +++++++ 7 files changed, 164 insertions(+), 15 deletions(-) create mode 100644 scikit_mol/_invalid.py create mode 100644 tests/test_invalid_handling.py diff --git a/scikit_mol/_invalid.py b/scikit_mol/_invalid.py new file mode 100644 index 0000000..900dc71 --- /dev/null +++ b/scikit_mol/_invalid.py @@ -0,0 +1,118 @@ +from abc import ABC +from typing import Any, Callable, NamedTuple, Sequence, TypeVar + +import numpy as np +import numpy.typing as npt + +_T = TypeVar("_T") +_U = TypeVar("_U") + + +class InvalidInstance(NamedTuple): + pipeline_step: str + error: str + + +class ArrayWithInvalidInstances: + invalid_list: list[InvalidInstance] + + def __init__(self, array_list: list[npt.NDArray[np.int8] | InvalidInstance]): + self.is_valid_array = get_is_valid_array(array_list) + valid_vector_list = filter_by_list(array_list, self.is_valid_array) + self.matrix = np.vstack(valid_vector_list) + self.invalid_list = filter_by_list(array_list, ~self.is_valid_array) + + def __getitem__(self, item: int) -> npt.NDArray[np.int8] | InvalidInstance: + n_invalids_prior = sum(~self.is_valid_array[:item - 1]) + if self.is_valid_array[item]: + return self.matrix[item - n_invalids_prior] + return self.invalid_list[n_invalids_prior + 1] + + def __setitem__(self, key: int, value: npt.NDArray[np.int8] | InvalidInstance) -> None: + n_invalids_prior = sum(~self.is_valid_array[:key - 1]) + if isinstance(value, InvalidInstance): + if self.is_valid_array[key]: + self.matrix = np.delete(self.matrix, key - n_invalids_prior) + self.is_valid_array[key] = False + self.invalid_list.insert(n_invalids_prior + 1, value) + else: + self.invalid_list[n_invalids_prior + 1] = value + else: + if self.is_valid_array[key]: + self.matrix[key - n_invalids_prior] = value + else: + self.matrix = np.insert(self.matrix, key-n_invalids_prior, value) + del(self.invalid_list[n_invalids_prior + 1]) + self.is_valid_array[key] = True + + +def update_list_by( + old_list: list[npt.NDArray[np.int8] | InvalidInstance] | ArrayWithInvalidInstances, + new_values: list[Any], + value_indices: npt.NDArray[np.int_], + ): + old_list = list(old_list) + for new_value, idx in zip(new_values, value_indices, strict=True): + old_list[idx] = new_value + return old_list + + +def filter_by_list(item_list, is_valid_array: npt.NDArray[np.bool_]): + if isinstance(item_list, np.ndarray): + return item_list[is_valid_array] + + item_list_new = [] + for item, is_valid in zip(item_list, is_valid_array): + if is_valid: + item_list_new.append(item) + return item_list_new + +# Callable[[Sequence[Any], Sequence[Any], dict[str, Any]], Sequence[Any]] +# ) -> Callable[[Sequence[Any], Sequence[Any], dict[str, Any]], npt.NDArray[Any]] +def rdkit_error_handling(func): + def wrapper(obj, *args, **kwargs): + x = args[0] + if isinstance(x, ArrayWithInvalidInstances): + is_valid_array = x.is_valid_array + x_sub = x.matrix + else: + is_valid_array = get_is_valid_array(x) + x_sub = filter_by_list(x, is_valid_array) + if len(args) > 1: + y = args[1] + y_sub = filter_by_list(y, is_valid_array) + else: + y_sub = None + x_new = func(obj, x_sub, y_sub, **kwargs) + new_pos = np.where(is_valid_array)[0] + if isinstance(x, (list, ArrayWithInvalidInstances)): + x_list = update_list_by(x, x_new, new_pos) + else: + x_array = np.array(x) + x_array[is_valid_array] = x_new + x_list = list(x_array) + if isinstance(x_new, ArrayWithInvalidInstances): + return ArrayWithInvalidInstances(x_list) + return x_list + return wrapper + + +def filter_rows( + X: Sequence[_T], y: Sequence[_U] +) -> tuple[Sequence[_T], Sequence[_U]]: + is_valid_array = get_is_valid_array(X) + x_new = filter_by_list(X, is_valid_array) + y_new = filter_by_list(y, is_valid_array) + return x_new, y_new + + +def get_is_valid_array(item_list: Sequence[Any]) -> npt.NDArray[np.bool_]: + is_valid_list = [] + for i, item in enumerate(item_list): + if not isinstance(item, InvalidInstance): + is_valid_list.append(True) + else: + is_valid_list.append(False) + return np.array(is_valid_list, dtype=bool) + + diff --git a/scikit_mol/conversions.py b/scikit_mol/conversions.py index c9d668a..914da37 100644 --- a/scikit_mol/conversions.py +++ b/scikit_mol/conversions.py @@ -6,6 +6,8 @@ import numpy as np from sklearn.base import BaseEstimator, TransformerMixin +from scikit_mol._invalid import InvalidInstance + class SmilesToMolTransformer(BaseEstimator, TransformerMixin): @@ -56,8 +58,7 @@ def _transform(self, X_smiles_list): if mol: X_out.append(mol) else: - raise ValueError(f'Issue with parsing SMILES {smiles}\nYou probably should use the scikit-mol.sanitizer.Sanitizer on your dataset first') - + X_out.append(InvalidInstance(str(self), "Invalid Smiles.")) return X_out def inverse_transform(self, X_mols_list, y=None): #TODO, maybe the inverse transform should be configurable e.g. isomericSmiles etc.? diff --git a/scikit_mol/fingerprints.py b/scikit_mol/fingerprints.py index 81bc43b..36bfdc0 100644 --- a/scikit_mol/fingerprints.py +++ b/scikit_mol/fingerprints.py @@ -16,9 +16,11 @@ from scipy.sparse import vstack from sklearn.base import BaseEstimator, TransformerMixin +from scikit_mol._invalid import ArrayWithInvalidInstances, rdkit_error_handling from abc import ABC, abstractmethod + #%% class FpsTransformer(ABC, BaseEstimator, TransformerMixin): @@ -49,10 +51,10 @@ def fit(self, X, y=None): return self def _transform(self, X): - arr = np.zeros((len(X), self.nBits), dtype=np.int8) + arr_list = [] for i, mol in enumerate(X): - arr[i,:] = self._transform_mol(mol) - return arr + arr_list.append(self._transform_mol(mol)) + return ArrayWithInvalidInstances(arr_list) def _transform_sparse(self, X): arr = np.zeros((len(X), self.nBits), dtype=np.int8) @@ -61,6 +63,7 @@ def _transform_sparse(self, X): return lil_matrix(arr) + @rdkit_error_handling def transform(self, X, y=None): """Transform a list of RDKit molecule objects into a fingerprint array @@ -89,9 +92,9 @@ def transform(self, X, y=None): #arrays = pool.map(self._transform, x_chunks) parameters = self.get_params() arrays = pool.map(parallel_helper, [(self.__class__.__name__, parameters, x_chunk) for x_chunk in x_chunks]) - - arr = np.concatenate(arrays) - return arr + arr_list = [] + arr_list.extend(arrays) + return arr_list class MACCSKeysFingerprintTransformer(FpsTransformer): diff --git a/scikit_mol/utilities.py b/scikit_mol/utilities.py index 24356d4..866a9aa 100644 --- a/scikit_mol/utilities.py +++ b/scikit_mol/utilities.py @@ -1,14 +1,9 @@ #For a non-scikit-learn check smiles sanitizer class -from typing import NamedTuple + import pandas as pd from rdkit import Chem -class InvalidInstance(NamedTuple): - pipeline_step: str - error: str - - class CheckSmilesSanitazion: def __init__(self, return_mol=False): self.return_mol = return_mol diff --git a/scikit_mol/wrapper.py b/scikit_mol/wrapper.py index e69de29..56efe9d 100644 --- a/scikit_mol/wrapper.py +++ b/scikit_mol/wrapper.py @@ -0,0 +1,7 @@ +from abc import ABC +from sklearn.base import BaseEstimator + + +class AbstractWrapper(BaseEstimator, ABC): + pass + diff --git a/tests/fixtures.py b/tests/fixtures.py index 4102825..f6aba9e 100644 --- a/tests/fixtures.py +++ b/tests/fixtures.py @@ -8,7 +8,7 @@ @pytest.fixture def smiles_list(): - return [Chem.MolToSmiles(Chem.MolFromSmiles(smiles)) for smiles in ['O=C(O)c1ccccc1', + return [Chem.MolToSmiles(Chem.MolFromSmiles(smiles)) for smiles in ['O=C(O)c1ccccc1', 'O=C([O-])c1ccccc1', 'O=C([O-])c1ccccc1.[Na+]', 'O=C(O[Na])c1ccccc1', diff --git a/tests/test_invalid_handling.py b/tests/test_invalid_handling.py new file mode 100644 index 0000000..fb95060 --- /dev/null +++ b/tests/test_invalid_handling.py @@ -0,0 +1,25 @@ +import pytest +from scikit_mol.conversions import SmilesToMolTransformer +from scikit_mol.fingerprints import MorganFingerprintTransformer +from sklearn.pipeline import Pipeline +from fixtures import smiles_list, invalid_smiles_list +from scikit_mol._invalid import ArrayWithInvalidInstances + +@pytest.fixture +def smilestofp_pipeline(): + pipeline = Pipeline( + [ + ("smiles_to_mol", SmilesToMolTransformer()), + ("mol_2_fp", MorganFingerprintTransformer()), + ] + + ) + return pipeline + + +def test_descriptor_transformer(invalid_smiles_list, smilestofp_pipeline): + smilestofp_pipeline.set_params() + mol_list: ArrayWithInvalidInstances = smilestofp_pipeline.transform(invalid_smiles_list) + print(mol_list.is_valid_array) + print(mol_list.matrix) + print(mol_list.invalid_list) From e5c6a204d5dabd081910cc103ff543f0c86fe2f5 Mon Sep 17 00:00:00 2001 From: "Christian W. Feldmann" Date: Fri, 22 Sep 2023 17:43:29 +0200 Subject: [PATCH 03/13] first working draft --- scikit_mol/_invalid.py | 61 ++++++++++++++++++++++------------ scikit_mol/fingerprints.py | 4 +-- scikit_mol/wrapper.py | 57 ++++++++++++++++++++++++++++++- tests/test_invalid_handling.py | 16 +++++---- 4 files changed, 107 insertions(+), 31 deletions(-) diff --git a/scikit_mol/_invalid.py b/scikit_mol/_invalid.py index 900dc71..d87a460 100644 --- a/scikit_mol/_invalid.py +++ b/scikit_mol/_invalid.py @@ -13,45 +13,52 @@ class InvalidInstance(NamedTuple): error: str -class ArrayWithInvalidInstances: +class NumpyArrayWithInvalidInstances: + is_valid_array: npt.NDArray[np.bool_] invalid_list: list[InvalidInstance] + value_array: npt.NDArray[Any] - def __init__(self, array_list: list[npt.NDArray[np.int8] | InvalidInstance]): + def __init__(self, array_list: list[npt.NDArray[Any] | InvalidInstance]): self.is_valid_array = get_is_valid_array(array_list) valid_vector_list = filter_by_list(array_list, self.is_valid_array) - self.matrix = np.vstack(valid_vector_list) + self.value_array = np.vstack(valid_vector_list) self.invalid_list = filter_by_list(array_list, ~self.is_valid_array) - def __getitem__(self, item: int) -> npt.NDArray[np.int8] | InvalidInstance: + def __getitem__(self, item: int) -> npt.NDArray[Any] | InvalidInstance: n_invalids_prior = sum(~self.is_valid_array[:item - 1]) if self.is_valid_array[item]: - return self.matrix[item - n_invalids_prior] + return self.value_array[item - n_invalids_prior] return self.invalid_list[n_invalids_prior + 1] - def __setitem__(self, key: int, value: npt.NDArray[np.int8] | InvalidInstance) -> None: + def __setitem__(self, key: int, value: npt.NDArray[Any] | InvalidInstance) -> None: n_invalids_prior = sum(~self.is_valid_array[:key - 1]) if isinstance(value, InvalidInstance): if self.is_valid_array[key]: - self.matrix = np.delete(self.matrix, key - n_invalids_prior) + self.value_array = np.delete(self.value_array, key - n_invalids_prior) self.is_valid_array[key] = False self.invalid_list.insert(n_invalids_prior + 1, value) else: self.invalid_list[n_invalids_prior + 1] = value else: if self.is_valid_array[key]: - self.matrix[key - n_invalids_prior] = value + self.value_array[key - n_invalids_prior] = value else: - self.matrix = np.insert(self.matrix, key-n_invalids_prior, value) + self.value_array = np.insert(self.value_array, key - n_invalids_prior, value) del(self.invalid_list[n_invalids_prior + 1]) self.is_valid_array[key] = True + def array_filled_with(self, fill_value) -> npt.NDArray[Any]: + out = np.full((len(self.is_valid_array), self.value_array.shape[1]), fill_value) + out[self.is_valid_array] = self.value_array + return out -def update_list_by( - old_list: list[npt.NDArray[np.int8] | InvalidInstance] | ArrayWithInvalidInstances, + +def batch_update_sequence( + old_list: list[npt.NDArray[Any] | InvalidInstance] | NumpyArrayWithInvalidInstances, new_values: list[Any], value_indices: npt.NDArray[np.int_], ): - old_list = list(old_list) + old_list = list(old_list) # Make shallow copy of list to avoid inplace changes. for new_value, idx in zip(new_values, value_indices, strict=True): old_list[idx] = new_value return old_list @@ -67,32 +74,44 @@ def filter_by_list(item_list, is_valid_array: npt.NDArray[np.bool_]): item_list_new.append(item) return item_list_new + # Callable[[Sequence[Any], Sequence[Any], dict[str, Any]], Sequence[Any]] # ) -> Callable[[Sequence[Any], Sequence[Any], dict[str, Any]], npt.NDArray[Any]] def rdkit_error_handling(func): def wrapper(obj, *args, **kwargs): x = args[0] - if isinstance(x, ArrayWithInvalidInstances): + if isinstance(x, NumpyArrayWithInvalidInstances): is_valid_array = x.is_valid_array - x_sub = x.matrix + x_sub = x.value_array else: is_valid_array = get_is_valid_array(x) x_sub = filter_by_list(x, is_valid_array) if len(args) > 1: y = args[1] - y_sub = filter_by_list(y, is_valid_array) + if y is not None: + y_sub = filter_by_list(y, is_valid_array) + else: + y_sub = None + x_new = func(obj, x_sub, y_sub, **kwargs) else: - y_sub = None - x_new = func(obj, x_sub, y_sub, **kwargs) + x_new = func(obj, x_sub, **kwargs) + + if x_new is None: # fit may not return anything + return None new_pos = np.where(is_valid_array)[0] - if isinstance(x, (list, ArrayWithInvalidInstances)): - x_list = update_list_by(x, x_new, new_pos) + if isinstance(x_new, np.ndarray) and isinstance(x, NumpyArrayWithInvalidInstances): + if x_new.shape[0] != x.value_array.shape[0]: + raise AssertionError("Numer of rows is not as expected.") + x.value_array = x_new + return x + if isinstance(x, (list, NumpyArrayWithInvalidInstances)): + x_list = batch_update_sequence(x, x_new, new_pos) else: x_array = np.array(x) x_array[is_valid_array] = x_new x_list = list(x_array) - if isinstance(x_new, ArrayWithInvalidInstances): - return ArrayWithInvalidInstances(x_list) + if isinstance(x_new, NumpyArrayWithInvalidInstances): + return NumpyArrayWithInvalidInstances(x_list) return x_list return wrapper diff --git a/scikit_mol/fingerprints.py b/scikit_mol/fingerprints.py index 36bfdc0..2998136 100644 --- a/scikit_mol/fingerprints.py +++ b/scikit_mol/fingerprints.py @@ -16,7 +16,7 @@ from scipy.sparse import vstack from sklearn.base import BaseEstimator, TransformerMixin -from scikit_mol._invalid import ArrayWithInvalidInstances, rdkit_error_handling +from scikit_mol._invalid import NumpyArrayWithInvalidInstances, rdkit_error_handling from abc import ABC, abstractmethod @@ -54,7 +54,7 @@ def _transform(self, X): arr_list = [] for i, mol in enumerate(X): arr_list.append(self._transform_mol(mol)) - return ArrayWithInvalidInstances(arr_list) + return NumpyArrayWithInvalidInstances(arr_list) def _transform_sparse(self, X): arr = np.zeros((len(X), self.nBits), dtype=np.int8) diff --git a/scikit_mol/wrapper.py b/scikit_mol/wrapper.py index 56efe9d..17787af 100644 --- a/scikit_mol/wrapper.py +++ b/scikit_mol/wrapper.py @@ -1,7 +1,62 @@ from abc import ABC +from typing import Any + +import numpy as np from sklearn.base import BaseEstimator +from sklearn.pipeline import Pipeline +from sklearn.utils.metaestimators import available_if + +from scikit_mol._invalid import rdkit_error_handling, InvalidInstance, NumpyArrayWithInvalidInstances class AbstractWrapper(BaseEstimator, ABC): - pass + model: BaseEstimator | Pipeline + + def __init__(self, replace_invalid: bool, replace_value=np.nan): + self.replace_invalid = replace_invalid + self.replace_value = replace_value + + @rdkit_error_handling + def fit(self, X, y, **fit_params) -> Any: + return self.model.fit(X, y, **fit_params) + + def has_predict(self) -> bool: + return hasattr(self.model, "predict") + + def has_fit_predict(self) -> bool: + return hasattr(self.model, "fit_predict") + + +class WrappedTransformer(AbstractWrapper): + def __init__(self, model: BaseEstimator, replace_invalid: bool = False, replace_value=np.nan): + super().__init__(replace_invalid=replace_invalid, replace_value=replace_value) + self.model = model + + def has_transform(self) -> bool: + return hasattr(self.model, "transform") + + def has_fit_transform(self) -> bool: + return hasattr(self.model, "fit_transform") + + @available_if(has_transform) + @rdkit_error_handling + def transform(self, X): + return self.model.transform(X) + + @rdkit_error_handling + def _fit_transform(self, X, y): + return self.model.fit_transform(X, y) + @available_if(has_fit_transform) + def fit_transform(self, X, y=None): + out = self._fit_transform(X,y) + if not self.replace_invalid: + return out + + if isinstance(out, NumpyArrayWithInvalidInstances): + return out.array_filled_with(self.replace_value) + + if isinstance(out, list): + return [self.replace_value if isinstance(v, InvalidInstance) else v for v in out] + + diff --git a/tests/test_invalid_handling.py b/tests/test_invalid_handling.py index fb95060..42e0344 100644 --- a/tests/test_invalid_handling.py +++ b/tests/test_invalid_handling.py @@ -1,9 +1,12 @@ import pytest -from scikit_mol.conversions import SmilesToMolTransformer -from scikit_mol.fingerprints import MorganFingerprintTransformer +from sklearn.decomposition import PCA from sklearn.pipeline import Pipeline + from fixtures import smiles_list, invalid_smiles_list -from scikit_mol._invalid import ArrayWithInvalidInstances +from scikit_mol.conversions import SmilesToMolTransformer +from scikit_mol.fingerprints import MorganFingerprintTransformer +from scikit_mol.wrapper import WrappedTransformer +from scikit_mol._invalid import NumpyArrayWithInvalidInstances @pytest.fixture def smilestofp_pipeline(): @@ -11,6 +14,7 @@ def smilestofp_pipeline(): [ ("smiles_to_mol", SmilesToMolTransformer()), ("mol_2_fp", MorganFingerprintTransformer()), + ("PCA", WrappedTransformer(PCA(2), replace_invalid=True)) ] ) @@ -19,7 +23,5 @@ def smilestofp_pipeline(): def test_descriptor_transformer(invalid_smiles_list, smilestofp_pipeline): smilestofp_pipeline.set_params() - mol_list: ArrayWithInvalidInstances = smilestofp_pipeline.transform(invalid_smiles_list) - print(mol_list.is_valid_array) - print(mol_list.matrix) - print(mol_list.invalid_list) + mol_list: NumpyArrayWithInvalidInstances = smilestofp_pipeline.fit_transform(invalid_smiles_list) + print(mol_list) From 2593ca212a647c280da65fb1c67a38daa4e3b5b8 Mon Sep 17 00:00:00 2001 From: "Christian W. Feldmann" Date: Fri, 13 Sep 2024 10:41:44 +0200 Subject: [PATCH 04/13] Add docstrings --- scikit_mol/wrapper.py | 35 ++++++++++++++++++++++++++++++++++- 1 file changed, 34 insertions(+), 1 deletion(-) diff --git a/scikit_mol/wrapper.py b/scikit_mol/wrapper.py index 17787af..1899e27 100644 --- a/scikit_mol/wrapper.py +++ b/scikit_mol/wrapper.py @@ -1,3 +1,5 @@ +"""Wrapper for sklearn estimators and pipelines to handle errors.""" + from abc import ABC from typing import Any @@ -10,9 +12,26 @@ class AbstractWrapper(BaseEstimator, ABC): + """ + Abstract class for the wrapper of sklearn objects. + + Attributes + ---------- + model: BaseEstimator | Pipeline + The wrapped model or pipeline. + """ model: BaseEstimator | Pipeline - def __init__(self, replace_invalid: bool, replace_value=np.nan): + def __init__(self, replace_invalid: bool, replace_value: Any = np.nan): + """Initialize the AbstractWrapper. + + Parameters + ---------- + replace_invalid: bool + Whether to replace or remove errors + replace_value: Any, default=np.nan + If replace_invalid==True, insert this value on the erroneous instance. + """ self.replace_invalid = replace_invalid self.replace_value = replace_value @@ -28,7 +47,20 @@ def has_fit_predict(self) -> bool: class WrappedTransformer(AbstractWrapper): + """Wrapper for sklearn transformer objects.""" + def __init__(self, model: BaseEstimator, replace_invalid: bool = False, replace_value=np.nan): + """Initialize the WrappedTransformer. + + Parameters + ---------- + model: BaseEstimator + Wrapped model to be protected against Errors. + replace_invalid: bool + Whether to replace or remove errors + replace_value: Any, default=np.nan + If replace_invalid==True, insert this value on the erroneous instance. + """ super().__init__(replace_invalid=replace_invalid, replace_value=replace_value) self.model = model @@ -46,6 +78,7 @@ def transform(self, X): @rdkit_error_handling def _fit_transform(self, X, y): return self.model.fit_transform(X, y) + @available_if(has_fit_transform) def fit_transform(self, X, y=None): out = self._fit_transform(X,y) From de5d0d892731323f4ea65eae8d2aa3675f37b291 Mon Sep 17 00:00:00 2001 From: "Christian W. Feldmann" Date: Fri, 13 Sep 2024 11:03:18 +0200 Subject: [PATCH 05/13] refactor unittest --- tests/test_invalid_handling.py | 23 ++++++++++++++++++++--- 1 file changed, 20 insertions(+), 3 deletions(-) diff --git a/tests/test_invalid_handling.py b/tests/test_invalid_handling.py index 42e0344..9ebaafc 100644 --- a/tests/test_invalid_handling.py +++ b/tests/test_invalid_handling.py @@ -1,3 +1,4 @@ +import numpy as np import pytest from sklearn.decomposition import PCA from sklearn.pipeline import Pipeline @@ -7,12 +8,15 @@ from scikit_mol.fingerprints import MorganFingerprintTransformer from scikit_mol.wrapper import WrappedTransformer from scikit_mol._invalid import NumpyArrayWithInvalidInstances +from tests.test_invalid_helpers.invalid_transformer import TestInvalidTransformer + @pytest.fixture def smilestofp_pipeline(): pipeline = Pipeline( [ ("smiles_to_mol", SmilesToMolTransformer()), + ("remove_sulfur", TestInvalidTransformer()), ("mol_2_fp", MorganFingerprintTransformer()), ("PCA", WrappedTransformer(PCA(2), replace_invalid=True)) ] @@ -21,7 +25,20 @@ def smilestofp_pipeline(): return pipeline -def test_descriptor_transformer(invalid_smiles_list, smilestofp_pipeline): +def test_descriptor_transformer(smiles_list, invalid_smiles_list, smilestofp_pipeline): smilestofp_pipeline.set_params() - mol_list: NumpyArrayWithInvalidInstances = smilestofp_pipeline.fit_transform(invalid_smiles_list) - print(mol_list) + mol_pca = smilestofp_pipeline.fit_transform(smiles_list) + error_mol_pca = smilestofp_pipeline.fit_transform(invalid_smiles_list) + + if mol_pca.shape != (len(smiles_list), 2): + raise ValueError("The PCA does not return the proper dimensions.") + if isinstance(error_mol_pca, NumpyArrayWithInvalidInstances): + raise TypeError("The Errors were not properly remove from the output array.") + + expected_nans = np.array([[0, 0, 1, 1], [0, 1, 0, 1]]) + if not np.all(np.equal(expected_nans, np.where(np.isnan(error_mol_pca)))): + raise ValueError("Errors were replaced on the wrong positions.") + + non_nan_rows = ~np.any(np.isnan(error_mol_pca), axis=1) + if not np.all(np.isclose(mol_pca, error_mol_pca[non_nan_rows, :])): + raise ValueError("Removing errors introduces changes in the PCA output.") From 7ddbfe122153c639b9aed247552552a33d8bb084 Mon Sep 17 00:00:00 2001 From: "Christian W. Feldmann" Date: Fri, 13 Sep 2024 11:03:44 +0200 Subject: [PATCH 06/13] add Transformer which can make molecules invalid --- .../invalid_transformer.py | 44 +++++++++++++++++++ 1 file changed, 44 insertions(+) create mode 100644 tests/test_invalid_helpers/invalid_transformer.py diff --git a/tests/test_invalid_helpers/invalid_transformer.py b/tests/test_invalid_helpers/invalid_transformer.py new file mode 100644 index 0000000..ffe2738 --- /dev/null +++ b/tests/test_invalid_helpers/invalid_transformer.py @@ -0,0 +1,44 @@ +from typing import Optional, Sequence +from sklearn.base import BaseEstimator, TransformerMixin +from rdkit import Chem + +from scikit_mol._invalid import ( + InvalidInstance, + rdkit_error_handling, +) + + +class TestInvalidTransformer(BaseEstimator, TransformerMixin): + """This class is ment for tesing purposes only. + + All molecules with element number are returned as invalid instance. + + Attributes + --------- + atomic_number_set: set[int] + Atomic numbers which upon occurrence in the molecule make it invalid. + """ + + atomic_number_set: set[int] + + def __init__(self, atomic_number_set: Sequence[int] | None = None) -> None: + if atomic_number_set is None: + atomic_number_set = {16} + self.atomic_number_set = set(atomic_number_set) + + def _transform_mol(self, mol: Chem.Mol) -> Chem.Mol | InvalidInstance: + unique_elements = {atom.GetAtomicNum() for atom in mol.GetAtoms()} + forbidden_elements = self.atomic_number_set & unique_elements + if forbidden_elements: + return InvalidInstance(str(self), f"Molecule contains {forbidden_elements}") + return mol + + @rdkit_error_handling + def transform(self, X: list[Chem.Mol]) -> list[Chem.Mol | InvalidInstance]: + return [self._transform_mol(mol) for mol in X] + + def fit(self, X, y, fit_params): + pass + + def fit_transform(self, X, y=None, **fit_params): + return self.transform(X) From 17f19d4428d028bf078ec78041214cb9750faa52 Mon Sep 17 00:00:00 2001 From: "Christian W. Feldmann" Date: Fri, 13 Sep 2024 11:04:10 +0200 Subject: [PATCH 07/13] Resolve merge issues with smiles_list --- tests/fixtures.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/fixtures.py b/tests/fixtures.py index 8cf0751..2852068 100644 --- a/tests/fixtures.py +++ b/tests/fixtures.py @@ -63,8 +63,9 @@ def chiral_smiles_list(): #Need to be a certain size, so the fingerprints reacts 'N[C@@H](C)C(=O)Oc1ccccc1CCCCCCCCCCCCCCCCCCN[H]']] @pytest.fixture -def invalid_smiles_list(smiles_list): - smiles_list.append('Invalid') +def invalid_smiles_list(): + smiles_list = ['S-CC', 'Invalid'] + smiles_list.extend(_SMILES_LIST) return smiles_list _MOLS_LIST = [Chem.MolFromSmiles(smiles) for smiles in _SMILES_LIST] From 6c1d3e68e16e77ac19414e28a5fa815b14d3c72f Mon Sep 17 00:00:00 2001 From: "Christian W. Feldmann" Date: Fri, 13 Sep 2024 11:04:21 +0200 Subject: [PATCH 08/13] Add docstring --- scikit_mol/_invalid.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/scikit_mol/_invalid.py b/scikit_mol/_invalid.py index d87a460..f850adf 100644 --- a/scikit_mol/_invalid.py +++ b/scikit_mol/_invalid.py @@ -9,11 +9,17 @@ class InvalidInstance(NamedTuple): + """ + The InvalidInstance represents objects which raised an error during a pipeline step. + """ pipeline_step: str error: str class NumpyArrayWithInvalidInstances: + """ + The NumpyArrayWithInvalidInstances is + """ is_valid_array: npt.NDArray[np.bool_] invalid_list: list[InvalidInstance] value_array: npt.NDArray[Any] From 591a2ac7afde7c4670a05ec901c1e616d9f2164e Mon Sep 17 00:00:00 2001 From: "Christian W. Feldmann" Date: Fri, 13 Sep 2024 11:05:24 +0200 Subject: [PATCH 09/13] Add init --- tests/test_invalid_helpers/__init__.py | 1 + 1 file changed, 1 insertion(+) create mode 100644 tests/test_invalid_helpers/__init__.py diff --git a/tests/test_invalid_helpers/__init__.py b/tests/test_invalid_helpers/__init__.py new file mode 100644 index 0000000..d0c583a --- /dev/null +++ b/tests/test_invalid_helpers/__init__.py @@ -0,0 +1 @@ +"""Initialize module for helper classes and functions used to test the handling of invalid inputs.""" From c713ece6809d4736b7f49980672eae6349447517 Mon Sep 17 00:00:00 2001 From: "Christian W. Feldmann" Date: Fri, 13 Sep 2024 11:10:53 +0200 Subject: [PATCH 10/13] Add Message encountering Errors --- scikit_mol/conversions.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/scikit_mol/conversions.py b/scikit_mol/conversions.py index db89661..2399eba 100644 --- a/scikit_mol/conversions.py +++ b/scikit_mol/conversions.py @@ -64,7 +64,8 @@ def _transform(self, X): if mol: X_out.append(mol) else: - X_out.append(InvalidInstance(str(self), "Invalid Smiles.")) + error = Chem.DetectChemistryProblems(mol) + X_out.append(InvalidInstance(str(self), f"Invalid Smiles.: {error}")) return X_out @check_transform_input From 85c9745b83d2e4335d754a2be2cd0b108924a59e Mon Sep 17 00:00:00 2001 From: "Christian W. Feldmann" Date: Fri, 13 Sep 2024 11:54:33 +0200 Subject: [PATCH 11/13] Fix Message encountering Errors --- .idea/.gitignore | 3 +++ .idea/inspectionProfiles/Project_Default.xml | 6 ++++++ .idea/inspectionProfiles/profiles_settings.xml | 6 ++++++ .idea/misc.xml | 10 ++++++++++ .idea/scikit-mol.iml | 13 +++++++++++++ .idea/vcs.xml | 6 ++++++ scikit_mol/conversions.py | 10 ++++++++-- 7 files changed, 52 insertions(+), 2 deletions(-) create mode 100644 .idea/.gitignore create mode 100644 .idea/inspectionProfiles/Project_Default.xml create mode 100644 .idea/inspectionProfiles/profiles_settings.xml create mode 100644 .idea/misc.xml create mode 100644 .idea/scikit-mol.iml create mode 100644 .idea/vcs.xml diff --git a/.idea/.gitignore b/.idea/.gitignore new file mode 100644 index 0000000..26d3352 --- /dev/null +++ b/.idea/.gitignore @@ -0,0 +1,3 @@ +# Default ignored files +/shelf/ +/workspace.xml diff --git a/.idea/inspectionProfiles/Project_Default.xml b/.idea/inspectionProfiles/Project_Default.xml new file mode 100644 index 0000000..9aa2337 --- /dev/null +++ b/.idea/inspectionProfiles/Project_Default.xml @@ -0,0 +1,6 @@ + + + + \ No newline at end of file diff --git a/.idea/inspectionProfiles/profiles_settings.xml b/.idea/inspectionProfiles/profiles_settings.xml new file mode 100644 index 0000000..105ce2d --- /dev/null +++ b/.idea/inspectionProfiles/profiles_settings.xml @@ -0,0 +1,6 @@ + + + + \ No newline at end of file diff --git a/.idea/misc.xml b/.idea/misc.xml new file mode 100644 index 0000000..13c8595 --- /dev/null +++ b/.idea/misc.xml @@ -0,0 +1,10 @@ + + + + + + + + \ No newline at end of file diff --git a/.idea/scikit-mol.iml b/.idea/scikit-mol.iml new file mode 100644 index 0000000..fb6d745 --- /dev/null +++ b/.idea/scikit-mol.iml @@ -0,0 +1,13 @@ + + + + + + + + + \ No newline at end of file diff --git a/.idea/vcs.xml b/.idea/vcs.xml new file mode 100644 index 0000000..35eb1dd --- /dev/null +++ b/.idea/vcs.xml @@ -0,0 +1,6 @@ + + + + + + \ No newline at end of file diff --git a/scikit_mol/conversions.py b/scikit_mol/conversions.py index 2399eba..b6537e7 100644 --- a/scikit_mol/conversions.py +++ b/scikit_mol/conversions.py @@ -64,8 +64,14 @@ def _transform(self, X): if mol: X_out.append(mol) else: - error = Chem.DetectChemistryProblems(mol) - X_out.append(InvalidInstance(str(self), f"Invalid Smiles.: {error}")) + mol = Chem.MolFromSmiles(smiles, sanitize=False) + if mol: + errors = Chem.DetectChemistryProblems(mol) + error_message = "\n".join(error.Message() for error in errors) + message = f"Invalid SMILES: {error_message}" + else: + message = f"Invalid SMILES: {smiles}" + X_out.append(InvalidInstance(str(self), message)) return X_out @check_transform_input From ce55a52f131d87d49ce884a5cb0e37b82bd2442e Mon Sep 17 00:00:00 2001 From: "Christian W. Feldmann" Date: Fri, 13 Sep 2024 11:54:50 +0200 Subject: [PATCH 12/13] Fix reference to test classes --- tests/test_invalid_handling.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/test_invalid_handling.py b/tests/test_invalid_handling.py index 9ebaafc..b848659 100644 --- a/tests/test_invalid_handling.py +++ b/tests/test_invalid_handling.py @@ -8,7 +8,7 @@ from scikit_mol.fingerprints import MorganFingerprintTransformer from scikit_mol.wrapper import WrappedTransformer from scikit_mol._invalid import NumpyArrayWithInvalidInstances -from tests.test_invalid_helpers.invalid_transformer import TestInvalidTransformer +from test_invalid_helpers.invalid_transformer import TestInvalidTransformer @pytest.fixture @@ -26,6 +26,7 @@ def smilestofp_pipeline(): def test_descriptor_transformer(smiles_list, invalid_smiles_list, smilestofp_pipeline): + smilestofp_pipeline.set_params() mol_pca = smilestofp_pipeline.fit_transform(smiles_list) error_mol_pca = smilestofp_pipeline.fit_transform(invalid_smiles_list) From 9e72d676134049b6de53d308f0880586de3fa6d9 Mon Sep 17 00:00:00 2001 From: "Christian W. Feldmann" Date: Fri, 13 Sep 2024 11:55:14 +0200 Subject: [PATCH 13/13] Add __len__ to class --- scikit_mol/_invalid.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/scikit_mol/_invalid.py b/scikit_mol/_invalid.py index f850adf..e1a91e5 100644 --- a/scikit_mol/_invalid.py +++ b/scikit_mol/_invalid.py @@ -30,6 +30,9 @@ def __init__(self, array_list: list[npt.NDArray[Any] | InvalidInstance]): self.value_array = np.vstack(valid_vector_list) self.invalid_list = filter_by_list(array_list, ~self.is_valid_array) + def __len__(self): + return self.is_valid_array.shape[0] + def __getitem__(self, item: int) -> npt.NDArray[Any] | InvalidInstance: n_invalids_prior = sum(~self.is_valid_array[:item - 1]) if self.is_valid_array[item]: