diff --git a/.idea/.gitignore b/.idea/.gitignore new file mode 100644 index 0000000..26d3352 --- /dev/null +++ b/.idea/.gitignore @@ -0,0 +1,3 @@ +# Default ignored files +/shelf/ +/workspace.xml diff --git a/.idea/inspectionProfiles/Project_Default.xml b/.idea/inspectionProfiles/Project_Default.xml new file mode 100644 index 0000000..9aa2337 --- /dev/null +++ b/.idea/inspectionProfiles/Project_Default.xml @@ -0,0 +1,6 @@ + + + + \ No newline at end of file diff --git a/.idea/inspectionProfiles/profiles_settings.xml b/.idea/inspectionProfiles/profiles_settings.xml new file mode 100644 index 0000000..105ce2d --- /dev/null +++ b/.idea/inspectionProfiles/profiles_settings.xml @@ -0,0 +1,6 @@ + + + + \ No newline at end of file diff --git a/.idea/misc.xml b/.idea/misc.xml new file mode 100644 index 0000000..13c8595 --- /dev/null +++ b/.idea/misc.xml @@ -0,0 +1,10 @@ + + + + + + + + \ No newline at end of file diff --git a/.idea/scikit-mol.iml b/.idea/scikit-mol.iml new file mode 100644 index 0000000..fb6d745 --- /dev/null +++ b/.idea/scikit-mol.iml @@ -0,0 +1,13 @@ + + + + + + + + + \ No newline at end of file diff --git a/.idea/vcs.xml b/.idea/vcs.xml new file mode 100644 index 0000000..35eb1dd --- /dev/null +++ b/.idea/vcs.xml @@ -0,0 +1,6 @@ + + + + + + \ No newline at end of file diff --git a/scikit_mol/_invalid.py b/scikit_mol/_invalid.py new file mode 100644 index 0000000..e1a91e5 --- /dev/null +++ b/scikit_mol/_invalid.py @@ -0,0 +1,146 @@ +from abc import ABC +from typing import Any, Callable, NamedTuple, Sequence, TypeVar + +import numpy as np +import numpy.typing as npt + +_T = TypeVar("_T") +_U = TypeVar("_U") + + +class InvalidInstance(NamedTuple): + """ + The InvalidInstance represents objects which raised an error during a pipeline step. + """ + pipeline_step: str + error: str + + +class NumpyArrayWithInvalidInstances: + """ + The NumpyArrayWithInvalidInstances is + """ + is_valid_array: npt.NDArray[np.bool_] + invalid_list: list[InvalidInstance] + value_array: npt.NDArray[Any] + + def __init__(self, array_list: list[npt.NDArray[Any] | InvalidInstance]): + self.is_valid_array = get_is_valid_array(array_list) + valid_vector_list = filter_by_list(array_list, self.is_valid_array) + self.value_array = np.vstack(valid_vector_list) + self.invalid_list = filter_by_list(array_list, ~self.is_valid_array) + + def __len__(self): + return self.is_valid_array.shape[0] + + def __getitem__(self, item: int) -> npt.NDArray[Any] | InvalidInstance: + n_invalids_prior = sum(~self.is_valid_array[:item - 1]) + if self.is_valid_array[item]: + return self.value_array[item - n_invalids_prior] + return self.invalid_list[n_invalids_prior + 1] + + def __setitem__(self, key: int, value: npt.NDArray[Any] | InvalidInstance) -> None: + n_invalids_prior = sum(~self.is_valid_array[:key - 1]) + if isinstance(value, InvalidInstance): + if self.is_valid_array[key]: + self.value_array = np.delete(self.value_array, key - n_invalids_prior) + self.is_valid_array[key] = False + self.invalid_list.insert(n_invalids_prior + 1, value) + else: + self.invalid_list[n_invalids_prior + 1] = value + else: + if self.is_valid_array[key]: + self.value_array[key - n_invalids_prior] = value + else: + self.value_array = np.insert(self.value_array, key - n_invalids_prior, value) + del(self.invalid_list[n_invalids_prior + 1]) + self.is_valid_array[key] = True + + def array_filled_with(self, fill_value) -> npt.NDArray[Any]: + out = np.full((len(self.is_valid_array), self.value_array.shape[1]), fill_value) + out[self.is_valid_array] = self.value_array + return out + + +def batch_update_sequence( + old_list: list[npt.NDArray[Any] | InvalidInstance] | NumpyArrayWithInvalidInstances, + new_values: list[Any], + value_indices: npt.NDArray[np.int_], + ): + old_list = list(old_list) # Make shallow copy of list to avoid inplace changes. + for new_value, idx in zip(new_values, value_indices, strict=True): + old_list[idx] = new_value + return old_list + + +def filter_by_list(item_list, is_valid_array: npt.NDArray[np.bool_]): + if isinstance(item_list, np.ndarray): + return item_list[is_valid_array] + + item_list_new = [] + for item, is_valid in zip(item_list, is_valid_array): + if is_valid: + item_list_new.append(item) + return item_list_new + + +# Callable[[Sequence[Any], Sequence[Any], dict[str, Any]], Sequence[Any]] +# ) -> Callable[[Sequence[Any], Sequence[Any], dict[str, Any]], npt.NDArray[Any]] +def rdkit_error_handling(func): + def wrapper(obj, *args, **kwargs): + x = args[0] + if isinstance(x, NumpyArrayWithInvalidInstances): + is_valid_array = x.is_valid_array + x_sub = x.value_array + else: + is_valid_array = get_is_valid_array(x) + x_sub = filter_by_list(x, is_valid_array) + if len(args) > 1: + y = args[1] + if y is not None: + y_sub = filter_by_list(y, is_valid_array) + else: + y_sub = None + x_new = func(obj, x_sub, y_sub, **kwargs) + else: + x_new = func(obj, x_sub, **kwargs) + + if x_new is None: # fit may not return anything + return None + new_pos = np.where(is_valid_array)[0] + if isinstance(x_new, np.ndarray) and isinstance(x, NumpyArrayWithInvalidInstances): + if x_new.shape[0] != x.value_array.shape[0]: + raise AssertionError("Numer of rows is not as expected.") + x.value_array = x_new + return x + if isinstance(x, (list, NumpyArrayWithInvalidInstances)): + x_list = batch_update_sequence(x, x_new, new_pos) + else: + x_array = np.array(x) + x_array[is_valid_array] = x_new + x_list = list(x_array) + if isinstance(x_new, NumpyArrayWithInvalidInstances): + return NumpyArrayWithInvalidInstances(x_list) + return x_list + return wrapper + + +def filter_rows( + X: Sequence[_T], y: Sequence[_U] +) -> tuple[Sequence[_T], Sequence[_U]]: + is_valid_array = get_is_valid_array(X) + x_new = filter_by_list(X, is_valid_array) + y_new = filter_by_list(y, is_valid_array) + return x_new, y_new + + +def get_is_valid_array(item_list: Sequence[Any]) -> npt.NDArray[np.bool_]: + is_valid_list = [] + for i, item in enumerate(item_list): + if not isinstance(item, InvalidInstance): + is_valid_list.append(True) + else: + is_valid_list.append(False) + return np.array(is_valid_list, dtype=bool) + + diff --git a/scikit_mol/conversions.py b/scikit_mol/conversions.py index 1c75ba5..b6537e7 100644 --- a/scikit_mol/conversions.py +++ b/scikit_mol/conversions.py @@ -8,6 +8,8 @@ from scikit_mol.core import check_transform_input, feature_names_default_mol ,DEFAULT_MOL_COLUMN_NAME +from scikit_mol._invalid import InvalidInstance + class SmilesToMolTransformer(BaseEstimator, TransformerMixin): @@ -62,9 +64,15 @@ def _transform(self, X): if mol: X_out.append(mol) else: - raise ValueError(f'Issue with parsing SMILES {smiles}\nYou probably should use the scikit-mol.sanitizer.Sanitizer on your dataset first') - - return np.array(X_out).reshape(-1,1) + mol = Chem.MolFromSmiles(smiles, sanitize=False) + if mol: + errors = Chem.DetectChemistryProblems(mol) + error_message = "\n".join(error.Message() for error in errors) + message = f"Invalid SMILES: {error_message}" + else: + message = f"Invalid SMILES: {smiles}" + X_out.append(InvalidInstance(str(self), message)) + return X_out @check_transform_input def inverse_transform(self, X_mols_list, y=None): #TODO, maybe the inverse transform should be configurable e.g. isomericSmiles etc.? diff --git a/scikit_mol/fingerprints.py b/scikit_mol/fingerprints.py index 767bfc6..65b08f8 100644 --- a/scikit_mol/fingerprints.py +++ b/scikit_mol/fingerprints.py @@ -18,9 +18,11 @@ from sklearn.base import BaseEstimator, TransformerMixin from scikit_mol.core import check_transform_input +from scikit_mol._invalid import NumpyArrayWithInvalidInstances, rdkit_error_handling from abc import ABC, abstractmethod + _PATTERN_FINGERPRINT_TRANSFORMER = re.compile(r"^(?P\w+)FingerprintTransformer$") #%% @@ -92,10 +94,10 @@ def fit(self, X, y=None): @check_transform_input def _transform(self, X): - arr = np.zeros((len(X), self.nBits), dtype=self._DTYPE_FINGERPRINT) + arr_list = [] for i, mol in enumerate(X): - arr[i,:] = self._transform_mol(mol) - return arr + arr_list.append(self._transform_mol(mol)) + return NumpyArrayWithInvalidInstances(arr_list) def _transform_sparse(self, X): arr = np.zeros((len(X), self.nBits), dtype=self._DTYPE_FINGERPRINT) @@ -104,6 +106,7 @@ def _transform_sparse(self, X): return lil_matrix(arr) + @rdkit_error_handling def transform(self, X, y=None): """Transform a list of RDKit molecule objects into a fingerprint array @@ -134,10 +137,10 @@ def transform(self, X, y=None): # TODO: create "transform_parallel" function in the core module, # and use it here and in the descriptors transformer #x_chunks = [np.array(x).reshape(-1, 1) for x in x_chunks] - arrays = pool.map(parallel_helper, [(self.__class__.__name__, parameters, x_chunk) for x_chunk in x_chunks]) - - arr = np.concatenate(arrays) - return arr + arrays = pool.map(parallel_helper, [(self.__class__.__name__, parameters, x_chunk) for x_chunk in x_chunks]) + arr_list = [] + arr_list.extend(arrays) + return arr_list class MACCSKeysFingerprintTransformer(FpsTransformer): diff --git a/scikit_mol/utilities.py b/scikit_mol/utilities.py index 70eac51..866a9aa 100644 --- a/scikit_mol/utilities.py +++ b/scikit_mol/utilities.py @@ -1,7 +1,9 @@ #For a non-scikit-learn check smiles sanitizer class + import pandas as pd from rdkit import Chem + class CheckSmilesSanitazion: def __init__(self, return_mol=False): self.return_mol = return_mol diff --git a/scikit_mol/wrapper.py b/scikit_mol/wrapper.py new file mode 100644 index 0000000..1899e27 --- /dev/null +++ b/scikit_mol/wrapper.py @@ -0,0 +1,95 @@ +"""Wrapper for sklearn estimators and pipelines to handle errors.""" + +from abc import ABC +from typing import Any + +import numpy as np +from sklearn.base import BaseEstimator +from sklearn.pipeline import Pipeline +from sklearn.utils.metaestimators import available_if + +from scikit_mol._invalid import rdkit_error_handling, InvalidInstance, NumpyArrayWithInvalidInstances + + +class AbstractWrapper(BaseEstimator, ABC): + """ + Abstract class for the wrapper of sklearn objects. + + Attributes + ---------- + model: BaseEstimator | Pipeline + The wrapped model or pipeline. + """ + model: BaseEstimator | Pipeline + + def __init__(self, replace_invalid: bool, replace_value: Any = np.nan): + """Initialize the AbstractWrapper. + + Parameters + ---------- + replace_invalid: bool + Whether to replace or remove errors + replace_value: Any, default=np.nan + If replace_invalid==True, insert this value on the erroneous instance. + """ + self.replace_invalid = replace_invalid + self.replace_value = replace_value + + @rdkit_error_handling + def fit(self, X, y, **fit_params) -> Any: + return self.model.fit(X, y, **fit_params) + + def has_predict(self) -> bool: + return hasattr(self.model, "predict") + + def has_fit_predict(self) -> bool: + return hasattr(self.model, "fit_predict") + + +class WrappedTransformer(AbstractWrapper): + """Wrapper for sklearn transformer objects.""" + + def __init__(self, model: BaseEstimator, replace_invalid: bool = False, replace_value=np.nan): + """Initialize the WrappedTransformer. + + Parameters + ---------- + model: BaseEstimator + Wrapped model to be protected against Errors. + replace_invalid: bool + Whether to replace or remove errors + replace_value: Any, default=np.nan + If replace_invalid==True, insert this value on the erroneous instance. + """ + super().__init__(replace_invalid=replace_invalid, replace_value=replace_value) + self.model = model + + def has_transform(self) -> bool: + return hasattr(self.model, "transform") + + def has_fit_transform(self) -> bool: + return hasattr(self.model, "fit_transform") + + @available_if(has_transform) + @rdkit_error_handling + def transform(self, X): + return self.model.transform(X) + + @rdkit_error_handling + def _fit_transform(self, X, y): + return self.model.fit_transform(X, y) + + @available_if(has_fit_transform) + def fit_transform(self, X, y=None): + out = self._fit_transform(X,y) + if not self.replace_invalid: + return out + + if isinstance(out, NumpyArrayWithInvalidInstances): + return out.array_filled_with(self.replace_value) + + if isinstance(out, list): + return [self.replace_value if isinstance(v, InvalidInstance) else v for v in out] + + + diff --git a/tests/fixtures.py b/tests/fixtures.py index 8cf0751..2852068 100644 --- a/tests/fixtures.py +++ b/tests/fixtures.py @@ -63,8 +63,9 @@ def chiral_smiles_list(): #Need to be a certain size, so the fingerprints reacts 'N[C@@H](C)C(=O)Oc1ccccc1CCCCCCCCCCCCCCCCCCN[H]']] @pytest.fixture -def invalid_smiles_list(smiles_list): - smiles_list.append('Invalid') +def invalid_smiles_list(): + smiles_list = ['S-CC', 'Invalid'] + smiles_list.extend(_SMILES_LIST) return smiles_list _MOLS_LIST = [Chem.MolFromSmiles(smiles) for smiles in _SMILES_LIST] diff --git a/tests/test_invalid_handling.py b/tests/test_invalid_handling.py new file mode 100644 index 0000000..b848659 --- /dev/null +++ b/tests/test_invalid_handling.py @@ -0,0 +1,45 @@ +import numpy as np +import pytest +from sklearn.decomposition import PCA +from sklearn.pipeline import Pipeline + +from fixtures import smiles_list, invalid_smiles_list +from scikit_mol.conversions import SmilesToMolTransformer +from scikit_mol.fingerprints import MorganFingerprintTransformer +from scikit_mol.wrapper import WrappedTransformer +from scikit_mol._invalid import NumpyArrayWithInvalidInstances +from test_invalid_helpers.invalid_transformer import TestInvalidTransformer + + +@pytest.fixture +def smilestofp_pipeline(): + pipeline = Pipeline( + [ + ("smiles_to_mol", SmilesToMolTransformer()), + ("remove_sulfur", TestInvalidTransformer()), + ("mol_2_fp", MorganFingerprintTransformer()), + ("PCA", WrappedTransformer(PCA(2), replace_invalid=True)) + ] + + ) + return pipeline + + +def test_descriptor_transformer(smiles_list, invalid_smiles_list, smilestofp_pipeline): + + smilestofp_pipeline.set_params() + mol_pca = smilestofp_pipeline.fit_transform(smiles_list) + error_mol_pca = smilestofp_pipeline.fit_transform(invalid_smiles_list) + + if mol_pca.shape != (len(smiles_list), 2): + raise ValueError("The PCA does not return the proper dimensions.") + if isinstance(error_mol_pca, NumpyArrayWithInvalidInstances): + raise TypeError("The Errors were not properly remove from the output array.") + + expected_nans = np.array([[0, 0, 1, 1], [0, 1, 0, 1]]) + if not np.all(np.equal(expected_nans, np.where(np.isnan(error_mol_pca)))): + raise ValueError("Errors were replaced on the wrong positions.") + + non_nan_rows = ~np.any(np.isnan(error_mol_pca), axis=1) + if not np.all(np.isclose(mol_pca, error_mol_pca[non_nan_rows, :])): + raise ValueError("Removing errors introduces changes in the PCA output.") diff --git a/tests/test_invalid_helpers/__init__.py b/tests/test_invalid_helpers/__init__.py new file mode 100644 index 0000000..d0c583a --- /dev/null +++ b/tests/test_invalid_helpers/__init__.py @@ -0,0 +1 @@ +"""Initialize module for helper classes and functions used to test the handling of invalid inputs.""" diff --git a/tests/test_invalid_helpers/invalid_transformer.py b/tests/test_invalid_helpers/invalid_transformer.py new file mode 100644 index 0000000..ffe2738 --- /dev/null +++ b/tests/test_invalid_helpers/invalid_transformer.py @@ -0,0 +1,44 @@ +from typing import Optional, Sequence +from sklearn.base import BaseEstimator, TransformerMixin +from rdkit import Chem + +from scikit_mol._invalid import ( + InvalidInstance, + rdkit_error_handling, +) + + +class TestInvalidTransformer(BaseEstimator, TransformerMixin): + """This class is ment for tesing purposes only. + + All molecules with element number are returned as invalid instance. + + Attributes + --------- + atomic_number_set: set[int] + Atomic numbers which upon occurrence in the molecule make it invalid. + """ + + atomic_number_set: set[int] + + def __init__(self, atomic_number_set: Sequence[int] | None = None) -> None: + if atomic_number_set is None: + atomic_number_set = {16} + self.atomic_number_set = set(atomic_number_set) + + def _transform_mol(self, mol: Chem.Mol) -> Chem.Mol | InvalidInstance: + unique_elements = {atom.GetAtomicNum() for atom in mol.GetAtoms()} + forbidden_elements = self.atomic_number_set & unique_elements + if forbidden_elements: + return InvalidInstance(str(self), f"Molecule contains {forbidden_elements}") + return mol + + @rdkit_error_handling + def transform(self, X: list[Chem.Mol]) -> list[Chem.Mol | InvalidInstance]: + return [self._transform_mol(mol) for mol in X] + + def fit(self, X, y, fit_params): + pass + + def fit_transform(self, X, y=None, **fit_params): + return self.transform(X)