diff --git a/.idea/.gitignore b/.idea/.gitignore
new file mode 100644
index 0000000..26d3352
--- /dev/null
+++ b/.idea/.gitignore
@@ -0,0 +1,3 @@
+# Default ignored files
+/shelf/
+/workspace.xml
diff --git a/.idea/inspectionProfiles/Project_Default.xml b/.idea/inspectionProfiles/Project_Default.xml
new file mode 100644
index 0000000..9aa2337
--- /dev/null
+++ b/.idea/inspectionProfiles/Project_Default.xml
@@ -0,0 +1,6 @@
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/inspectionProfiles/profiles_settings.xml b/.idea/inspectionProfiles/profiles_settings.xml
new file mode 100644
index 0000000..105ce2d
--- /dev/null
+++ b/.idea/inspectionProfiles/profiles_settings.xml
@@ -0,0 +1,6 @@
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/misc.xml b/.idea/misc.xml
new file mode 100644
index 0000000..13c8595
--- /dev/null
+++ b/.idea/misc.xml
@@ -0,0 +1,10 @@
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/scikit-mol.iml b/.idea/scikit-mol.iml
new file mode 100644
index 0000000..fb6d745
--- /dev/null
+++ b/.idea/scikit-mol.iml
@@ -0,0 +1,13 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/vcs.xml b/.idea/vcs.xml
new file mode 100644
index 0000000..35eb1dd
--- /dev/null
+++ b/.idea/vcs.xml
@@ -0,0 +1,6 @@
+
+
+
+
+
+
\ No newline at end of file
diff --git a/scikit_mol/_invalid.py b/scikit_mol/_invalid.py
new file mode 100644
index 0000000..e1a91e5
--- /dev/null
+++ b/scikit_mol/_invalid.py
@@ -0,0 +1,146 @@
+from abc import ABC
+from typing import Any, Callable, NamedTuple, Sequence, TypeVar
+
+import numpy as np
+import numpy.typing as npt
+
+_T = TypeVar("_T")
+_U = TypeVar("_U")
+
+
+class InvalidInstance(NamedTuple):
+ """
+ Placeholder for an input item that could not be processed by a pipeline step.
+
+ Instances take the place of the failed item in a sequence, so that the
+ positions of the remaining items are preserved.
+ """
+ # Identifier of the step that rejected the item (callers in this change pass ``str(self)``).
+ pipeline_step: str
+ # Human-readable description of why the item was rejected.
+ error: str
+
+
+class NumpyArrayWithInvalidInstances:
+    """Row container mixing valid numpy rows with ``InvalidInstance`` placeholders.
+
+    The valid rows are stored densely in ``value_array`` (one row per valid
+    item), the placeholders are stored in ``invalid_list`` in their original
+    order, and ``is_valid_array`` records for every logical position whether
+    it holds a valid row (``True``) or a placeholder (``False``).
+
+    Attributes
+    ----------
+    is_valid_array: npt.NDArray[np.bool_]
+        Boolean mask with one entry per logical row.
+    invalid_list: list[InvalidInstance]
+        Placeholders of the invalid rows, in order of appearance.
+    value_array: npt.NDArray[Any]
+        2-D array holding only the valid rows, in order of appearance.
+    """
+    is_valid_array: npt.NDArray[np.bool_]
+    invalid_list: list[InvalidInstance]
+    value_array: npt.NDArray[Any]
+
+    def __init__(self, array_list: list[npt.NDArray[Any] | InvalidInstance]):
+        """Split ``array_list`` into valid rows and invalid placeholders.
+
+        Parameters
+        ----------
+        array_list: list[npt.NDArray[Any] | InvalidInstance]
+            One entry per logical row: either a row vector or a placeholder.
+        """
+        self.is_valid_array = get_is_valid_array(array_list)
+        valid_vector_list = filter_by_list(array_list, self.is_valid_array)
+        if len(valid_vector_list) > 0:
+            self.value_array = np.vstack(valid_vector_list)
+        else:
+            # np.vstack raises on an empty sequence; without any valid row the
+            # number of columns is unknown, so start with an empty 2-D array.
+            self.value_array = np.empty((0, 0))
+        self.invalid_list = list(filter_by_list(array_list, ~self.is_valid_array))
+
+    def __len__(self):
+        """Return the number of logical rows (valid and invalid)."""
+        return self.is_valid_array.shape[0]
+
+    def _locate(self, index: int) -> tuple[int, int]:
+        """Normalize ``index`` and count the invalid rows located before it.
+
+        Returns
+        -------
+        tuple[int, int]
+            The non-negative position and the number of invalid rows in
+            ``[0, position)``.
+
+        Raises
+        ------
+        IndexError
+            If ``index`` is outside of the container. This is also what ends
+            iteration through the ``__getitem__`` sequence protocol.
+        """
+        n_rows = len(self)
+        position = int(index)
+        if position < 0:
+            position += n_rows
+        if not 0 <= position < n_rows:
+            raise IndexError(f"index {index} is out of bounds for length {n_rows}")
+        n_invalids_prior = int(np.count_nonzero(~self.is_valid_array[:position]))
+        return position, n_invalids_prior
+
+    def __getitem__(self, item: int) -> npt.NDArray[Any] | InvalidInstance:
+        """Return the row vector or the placeholder stored at position ``item``."""
+        position, n_invalids_prior = self._locate(item)
+        if self.is_valid_array[position]:
+            # Valid rows are stored densely: skip the invalid ones before it.
+            return self.value_array[position - n_invalids_prior]
+        return self.invalid_list[n_invalids_prior]
+
+    def __setitem__(self, key: int, value: npt.NDArray[Any] | InvalidInstance) -> None:
+        """Replace the entry at position ``key`` by a row vector or a placeholder."""
+        position, n_invalids_prior = self._locate(key)
+        row = position - n_invalids_prior  # position among the valid rows
+        if isinstance(value, InvalidInstance):
+            if self.is_valid_array[position]:
+                # valid -> invalid: drop the row (axis=0 keeps the array 2-D).
+                self.value_array = np.delete(self.value_array, row, axis=0)
+                self.is_valid_array[position] = False
+                self.invalid_list.insert(n_invalids_prior, value)
+            else:
+                self.invalid_list[n_invalids_prior] = value
+        else:
+            if self.is_valid_array[position]:
+                self.value_array[row] = value
+            else:
+                # invalid -> valid: insert a new row and drop the placeholder.
+                if self.value_array.shape[0] == 0:
+                    # No valid row yet, so the new row defines the width.
+                    self.value_array = np.asarray(value).reshape(1, -1)
+                else:
+                    self.value_array = np.insert(self.value_array, row, value, axis=0)
+                del self.invalid_list[n_invalids_prior]
+                self.is_valid_array[position] = True
+
+    def array_filled_with(self, fill_value) -> npt.NDArray[Any]:
+        """Return a dense 2-D array in which invalid rows are set to ``fill_value``.
+
+        The dtype of the result is chosen so that it can hold both the valid
+        values and ``fill_value``; otherwise e.g. an integer fill value would
+        silently truncate float rows.
+        """
+        fill_dtype = np.asarray(fill_value).dtype
+        try:
+            out_dtype = np.result_type(self.value_array.dtype, fill_dtype)
+        except TypeError:
+            # No common dtype (e.g. string fill value): keep the dtype numpy
+            # derives from the fill value alone.
+            out_dtype = fill_dtype
+        shape = (len(self.is_valid_array), self.value_array.shape[1])
+        out = np.full(shape, fill_value, dtype=out_dtype)
+        out[self.is_valid_array] = self.value_array
+        return out
+
+
+def batch_update_sequence(
+    old_list: list[npt.NDArray[Any] | InvalidInstance] | NumpyArrayWithInvalidInstances,
+    new_values: list[Any],
+    value_indices: npt.NDArray[np.int_],
+):
+    """Return a list copy of ``old_list`` with selected positions overwritten.
+
+    ``new_values[k]`` is written to position ``value_indices[k]``. The input
+    sequence itself is left untouched; both update sequences must have the
+    same length (enforced by ``zip(..., strict=True)``).
+    """
+    updated = list(old_list)  # shallow copy, so the caller's sequence is not changed
+    for replacement, position in zip(new_values, value_indices, strict=True):
+        updated[position] = replacement
+    return updated
+
+
+def filter_by_list(item_list, is_valid_array: npt.NDArray[np.bool_]):
+    """Keep only the items of ``item_list`` whose mask entry is truthy.
+
+    Numpy arrays are filtered by boolean indexing and returned as an array;
+    any other iterable is filtered element-wise and returned as a list.
+    """
+    if isinstance(item_list, np.ndarray):
+        return item_list[is_valid_array]
+    return [entry for entry, keep in zip(item_list, is_valid_array) if keep]
+
+
+# Callable[[Sequence[Any], Sequence[Any], dict[str, Any]], Sequence[Any]]
+# ) -> Callable[[Sequence[Any], Sequence[Any], dict[str, Any]], npt.NDArray[Any]]
+def rdkit_error_handling(func):
+    """Decorate a method so that it only ever sees the valid rows of ``X``.
+
+    The wrapper removes all ``InvalidInstance`` entries from ``X`` (and the
+    matching entries from ``y``, if given), calls ``func`` on the remaining
+    rows and writes the results back to the positions of the valid rows, so
+    that the output is aligned with the original input again.
+
+    Parameters
+    ----------
+    func: Callable
+        Method with the signature ``func(obj, X, [y], ...)``.
+
+    Returns
+    -------
+    Callable
+        Wrapper with the metadata (name, docstring, signature) of ``func``.
+        It returns ``None`` if ``func`` returns ``None``, a new
+        ``NumpyArrayWithInvalidInstances`` if both input and result are array
+        based, and otherwise a list in which the valid positions hold the
+        new values.
+    """
+    # Local imports keep this block independent of the module import section.
+    import copy
+    import functools
+
+    @functools.wraps(func)  # keep name, docstring and signature of ``func``
+    def wrapper(obj, *args, **kwargs):
+        extra_args = list(args)
+        if extra_args:
+            x = extra_args.pop(0)
+        elif "X" in kwargs:
+            x = kwargs.pop("X")
+        else:
+            raise TypeError(f"{func.__name__}() missing required argument 'X'")
+
+        if isinstance(x, NumpyArrayWithInvalidInstances):
+            is_valid_array = x.is_valid_array
+            x_sub = x.value_array
+        else:
+            is_valid_array = get_is_valid_array(x)
+            x_sub = filter_by_list(x, is_valid_array)
+
+        # y may arrive positionally or as keyword; it has to be filtered with
+        # the same mask as X to keep both aligned.
+        if extra_args:
+            if extra_args[0] is not None:
+                extra_args[0] = filter_by_list(extra_args[0], is_valid_array)
+        elif kwargs.get("y") is not None:
+            kwargs["y"] = filter_by_list(kwargs["y"], is_valid_array)
+
+        x_new = func(obj, x_sub, *extra_args, **kwargs)
+
+        if x_new is None:  # fit may not return anything
+            return None
+        new_pos = np.where(is_valid_array)[0]
+        if isinstance(x_new, np.ndarray) and isinstance(x, NumpyArrayWithInvalidInstances):
+            if x_new.shape[0] != x.value_array.shape[0]:
+                raise AssertionError("Number of rows is not as expected.")
+            # Build a new container instead of overwriting the values of the
+            # caller's input in place.
+            result = copy.copy(x)
+            result.is_valid_array = x.is_valid_array.copy()
+            result.invalid_list = list(x.invalid_list)
+            result.value_array = x_new
+            return result
+        if isinstance(x, (list, NumpyArrayWithInvalidInstances)):
+            x_list = batch_update_sequence(x, x_new, new_pos)
+        else:
+            x_array = np.array(x)
+            x_array[is_valid_array] = x_new
+            x_list = list(x_array)
+        if isinstance(x_new, NumpyArrayWithInvalidInstances):
+            return NumpyArrayWithInvalidInstances(x_list)
+        return x_list
+
+    return wrapper
+
+
+def filter_rows(
+    X: Sequence[_T], y: Sequence[_U]
+) -> tuple[Sequence[_T], Sequence[_U]]:
+    """Drop every position of ``X`` and ``y`` at which ``X`` holds an InvalidInstance.
+
+    The validity mask is derived from ``X`` only and applied to both inputs.
+    """
+    keep_mask = get_is_valid_array(X)
+    return filter_by_list(X, keep_mask), filter_by_list(y, keep_mask)
+
+
+def get_is_valid_array(item_list: Sequence[Any]) -> npt.NDArray[np.bool_]:
+    """Return a boolean mask that is ``False`` exactly where an item is an InvalidInstance."""
+    validity_flags = [not isinstance(entry, InvalidInstance) for entry in item_list]
+    return np.array(validity_flags, dtype=bool)
+
+
diff --git a/scikit_mol/conversions.py b/scikit_mol/conversions.py
index 1c75ba5..b6537e7 100644
--- a/scikit_mol/conversions.py
+++ b/scikit_mol/conversions.py
@@ -8,6 +8,8 @@
from scikit_mol.core import check_transform_input, feature_names_default_mol ,DEFAULT_MOL_COLUMN_NAME
+from scikit_mol._invalid import InvalidInstance
+
class SmilesToMolTransformer(BaseEstimator, TransformerMixin):
@@ -62,9 +64,15 @@ def _transform(self, X):
if mol:
X_out.append(mol)
else:
- raise ValueError(f'Issue with parsing SMILES {smiles}\nYou probably should use the scikit-mol.sanitizer.Sanitizer on your dataset first')
-
- return np.array(X_out).reshape(-1,1)
+ mol = Chem.MolFromSmiles(smiles, sanitize=False)
+ if mol:
+ errors = Chem.DetectChemistryProblems(mol)
+ error_message = "\n".join(error.Message() for error in errors)
+ message = f"Invalid SMILES: {error_message}"
+ else:
+ message = f"Invalid SMILES: {smiles}"
+ X_out.append(InvalidInstance(str(self), message))
+ return X_out
@check_transform_input
def inverse_transform(self, X_mols_list, y=None): #TODO, maybe the inverse transform should be configurable e.g. isomericSmiles etc.?
diff --git a/scikit_mol/fingerprints.py b/scikit_mol/fingerprints.py
index 767bfc6..65b08f8 100644
--- a/scikit_mol/fingerprints.py
+++ b/scikit_mol/fingerprints.py
@@ -18,9 +18,11 @@
from sklearn.base import BaseEstimator, TransformerMixin
from scikit_mol.core import check_transform_input
+from scikit_mol._invalid import NumpyArrayWithInvalidInstances, rdkit_error_handling
from abc import ABC, abstractmethod
+
_PATTERN_FINGERPRINT_TRANSFORMER = re.compile(r"^(?P\w+)FingerprintTransformer$")
#%%
@@ -92,10 +94,10 @@ def fit(self, X, y=None):
@check_transform_input
def _transform(self, X):
- arr = np.zeros((len(X), self.nBits), dtype=self._DTYPE_FINGERPRINT)
+ arr_list = []
for i, mol in enumerate(X):
- arr[i,:] = self._transform_mol(mol)
- return arr
+ arr_list.append(self._transform_mol(mol))
+ return NumpyArrayWithInvalidInstances(arr_list)
def _transform_sparse(self, X):
arr = np.zeros((len(X), self.nBits), dtype=self._DTYPE_FINGERPRINT)
@@ -104,6 +106,7 @@ def _transform_sparse(self, X):
return lil_matrix(arr)
+ @rdkit_error_handling
def transform(self, X, y=None):
"""Transform a list of RDKit molecule objects into a fingerprint array
@@ -134,10 +137,10 @@ def transform(self, X, y=None):
# TODO: create "transform_parallel" function in the core module,
# and use it here and in the descriptors transformer
#x_chunks = [np.array(x).reshape(-1, 1) for x in x_chunks]
- arrays = pool.map(parallel_helper, [(self.__class__.__name__, parameters, x_chunk) for x_chunk in x_chunks])
-
- arr = np.concatenate(arrays)
- return arr
+ arrays = pool.map(parallel_helper, [(self.__class__.__name__, parameters, x_chunk) for x_chunk in x_chunks])
+ arr_list = []
+ arr_list.extend(arrays)
+ return arr_list
class MACCSKeysFingerprintTransformer(FpsTransformer):
diff --git a/scikit_mol/utilities.py b/scikit_mol/utilities.py
index 70eac51..866a9aa 100644
--- a/scikit_mol/utilities.py
+++ b/scikit_mol/utilities.py
@@ -1,7 +1,9 @@
#For a non-scikit-learn check smiles sanitizer class
+
import pandas as pd
from rdkit import Chem
+
class CheckSmilesSanitazion:
def __init__(self, return_mol=False):
self.return_mol = return_mol
diff --git a/scikit_mol/wrapper.py b/scikit_mol/wrapper.py
new file mode 100644
index 0000000..1899e27
--- /dev/null
+++ b/scikit_mol/wrapper.py
@@ -0,0 +1,95 @@
+"""Wrapper for sklearn estimators and pipelines to handle errors."""
+
+from abc import ABC
+from typing import Any
+
+import numpy as np
+from sklearn.base import BaseEstimator
+from sklearn.pipeline import Pipeline
+from sklearn.utils.metaestimators import available_if
+
+from scikit_mol._invalid import rdkit_error_handling, InvalidInstance, NumpyArrayWithInvalidInstances
+
+
+class AbstractWrapper(BaseEstimator, ABC):
+    """
+    Abstract class for the wrapper of sklearn objects.
+
+    The wrapper hides ``InvalidInstance`` entries from the wrapped model:
+    the model is only fitted on the valid rows of the input.
+
+    Attributes
+    ----------
+    model: BaseEstimator | Pipeline
+        The wrapped model or pipeline.
+    """
+    model: BaseEstimator | Pipeline
+
+    def __init__(self, replace_invalid: bool, replace_value: Any = np.nan):
+        """Initialize the AbstractWrapper.
+
+        Parameters
+        ----------
+        replace_invalid: bool
+            Whether to replace or remove errors
+        replace_value: Any, default=np.nan
+            If replace_invalid==True, insert this value on the erroneous instance.
+        """
+        self.replace_invalid = replace_invalid
+        self.replace_value = replace_value
+
+    @rdkit_error_handling
+    def _fit(self, X, y=None, **fit_params) -> None:
+        """Fit the wrapped model on the valid rows only.
+
+        Returns ``None`` on purpose: the decorator treats every other return
+        value as per-row output that has to be merged back into ``X``.
+        """
+        self.model.fit(X, y, **fit_params)
+
+    def fit(self, X, y=None, **fit_params) -> Any:
+        """Fit the wrapped model, ignoring invalid instances in ``X``.
+
+        Parameters
+        ----------
+        X: Any
+            Input samples, possibly containing ``InvalidInstance`` entries.
+        y: Any, optional
+            Targets aligned with ``X``; rows of invalid samples are dropped.
+        **fit_params: Any
+            Passed on to the ``fit`` method of the wrapped model.
+
+        Returns
+        -------
+        Any
+            The wrapper itself, following the sklearn convention.
+        """
+        # X and y are passed positionally so that the decorator filters both.
+        self._fit(X, y, **fit_params)
+        return self
+
+    def has_predict(self) -> bool:
+        """Return whether the wrapped model provides ``predict``."""
+        return hasattr(self.model, "predict")
+
+    def has_fit_predict(self) -> bool:
+        """Return whether the wrapped model provides ``fit_predict``."""
+        return hasattr(self.model, "fit_predict")
+
+
+class WrappedTransformer(AbstractWrapper):
+    """Wrapper for sklearn transformer objects.
+
+    The wrapped transformer is only applied to the valid rows. Depending on
+    ``replace_invalid``, the invalid rows are either kept as
+    ``InvalidInstance`` entries or replaced by ``replace_value``.
+    """
+
+    def __init__(self, model: BaseEstimator, replace_invalid: bool = False, replace_value=np.nan):
+        """Initialize the WrappedTransformer.
+
+        Parameters
+        ----------
+        model: BaseEstimator
+            Wrapped model to be protected against Errors.
+        replace_invalid: bool
+            Whether to replace or remove errors
+        replace_value: Any, default=np.nan
+            If replace_invalid==True, insert this value on the erroneous instance.
+        """
+        super().__init__(replace_invalid=replace_invalid, replace_value=replace_value)
+        self.model = model
+
+    def has_transform(self) -> bool:
+        """Return whether the wrapped model provides ``transform``."""
+        return hasattr(self.model, "transform")
+
+    def has_fit_transform(self) -> bool:
+        """Return whether the wrapped model provides ``fit_transform``."""
+        return hasattr(self.model, "fit_transform")
+
+    def _replace_invalid_rows(self, out):
+        """Apply the ``replace_invalid`` policy to the output of the wrapped model.
+
+        Outputs which can not contain invalid instances (e.g. plain numpy
+        arrays) are returned unchanged instead of being dropped.
+        """
+        if not self.replace_invalid:
+            return out
+        if isinstance(out, NumpyArrayWithInvalidInstances):
+            return out.array_filled_with(self.replace_value)
+        if isinstance(out, list):
+            return [self.replace_value if isinstance(v, InvalidInstance) else v for v in out]
+        return out
+
+    @rdkit_error_handling
+    def _transform(self, X):
+        """Transform the valid rows of ``X`` with the wrapped model."""
+        return self.model.transform(X)
+
+    @available_if(has_transform)
+    def transform(self, X):
+        """Transform ``X``, handling invalid rows as configured by ``replace_invalid``."""
+        return self._replace_invalid_rows(self._transform(X))
+
+    @rdkit_error_handling
+    def _fit_transform(self, X, y):
+        """Fit the wrapped model on the valid rows of ``X`` and transform them."""
+        return self.model.fit_transform(X, y)
+
+    @available_if(has_fit_transform)
+    def fit_transform(self, X, y=None):
+        """Fit and transform ``X``, handling invalid rows as configured by ``replace_invalid``."""
+        return self._replace_invalid_rows(self._fit_transform(X, y))
+
+
+
diff --git a/tests/fixtures.py b/tests/fixtures.py
index 8cf0751..2852068 100644
--- a/tests/fixtures.py
+++ b/tests/fixtures.py
@@ -63,8 +63,9 @@ def chiral_smiles_list(): #Need to be a certain size, so the fingerprints reacts
'N[C@@H](C)C(=O)Oc1ccccc1CCCCCCCCCCCCCCCCCCN[H]']]
@pytest.fixture
-def invalid_smiles_list(smiles_list):
- smiles_list.append('Invalid')
+def invalid_smiles_list():
+    """Return a fresh list: two extra SMILES ('S-CC', 'Invalid') followed by ``_SMILES_LIST``."""
+    return ['S-CC', 'Invalid', *_SMILES_LIST]
_MOLS_LIST = [Chem.MolFromSmiles(smiles) for smiles in _SMILES_LIST]
diff --git a/tests/test_invalid_handling.py b/tests/test_invalid_handling.py
new file mode 100644
index 0000000..b848659
--- /dev/null
+++ b/tests/test_invalid_handling.py
@@ -0,0 +1,45 @@
+import numpy as np
+import pytest
+from sklearn.decomposition import PCA
+from sklearn.pipeline import Pipeline
+
+from fixtures import smiles_list, invalid_smiles_list
+from scikit_mol.conversions import SmilesToMolTransformer
+from scikit_mol.fingerprints import MorganFingerprintTransformer
+from scikit_mol.wrapper import WrappedTransformer
+from scikit_mol._invalid import NumpyArrayWithInvalidInstances
+from test_invalid_helpers.invalid_transformer import TestInvalidTransformer
+
+
+@pytest.fixture
+def smilestofp_pipeline():
+    """Build the SMILES -> mol -> sulfur filter -> Morgan fingerprint -> wrapped PCA pipeline."""
+    steps = [
+        ("smiles_to_mol", SmilesToMolTransformer()),
+        ("remove_sulfur", TestInvalidTransformer()),
+        ("mol_2_fp", MorganFingerprintTransformer()),
+        ("PCA", WrappedTransformer(PCA(2), replace_invalid=True)),
+    ]
+    return Pipeline(steps)
+
+
+def test_descriptor_transformer(smiles_list, invalid_smiles_list, smilestofp_pipeline):
+    """Check that invalid inputs become NaN rows and leave the other rows unchanged.
+
+    The pipeline is fitted once on the clean SMILES and once on the list with
+    invalid entries; rows 0 and 1 of the second output are expected to be NaN
+    in both PCA columns, all remaining rows have to match the clean output.
+    """
+    smilestofp_pipeline.set_params()
+    mol_pca = smilestofp_pipeline.fit_transform(smiles_list)
+    error_mol_pca = smilestofp_pipeline.fit_transform(invalid_smiles_list)
+
+    expected_shape = (len(smiles_list), 2)
+    if mol_pca.shape != expected_shape:
+        raise ValueError("The PCA does not return the proper dimensions.")
+    if isinstance(error_mol_pca, NumpyArrayWithInvalidInstances):
+        raise TypeError("The Errors were not properly remove from the output array.")
+
+    nan_mask = np.isnan(error_mol_pca)
+    # Row and column indices of the expected NaN entries.
+    expected_nans = np.array([[0, 0, 1, 1], [0, 1, 0, 1]])
+    observed_nans = np.where(nan_mask)
+    if not np.all(np.equal(expected_nans, observed_nans)):
+        raise ValueError("Errors were replaced on the wrong positions.")
+
+    non_nan_rows = ~np.any(nan_mask, axis=1)
+    valid_rows = error_mol_pca[non_nan_rows, :]
+    if not np.all(np.isclose(mol_pca, valid_rows)):
+        raise ValueError("Removing errors introduces changes in the PCA output.")
diff --git a/tests/test_invalid_helpers/__init__.py b/tests/test_invalid_helpers/__init__.py
new file mode 100644
index 0000000..d0c583a
--- /dev/null
+++ b/tests/test_invalid_helpers/__init__.py
@@ -0,0 +1 @@
+"""Initialize module for helper classes and functions used to test the handling of invalid inputs."""
diff --git a/tests/test_invalid_helpers/invalid_transformer.py b/tests/test_invalid_helpers/invalid_transformer.py
new file mode 100644
index 0000000..ffe2738
--- /dev/null
+++ b/tests/test_invalid_helpers/invalid_transformer.py
@@ -0,0 +1,44 @@
+from typing import Optional, Sequence
+from sklearn.base import BaseEstimator, TransformerMixin
+from rdkit import Chem
+
+from scikit_mol._invalid import (
+ InvalidInstance,
+ rdkit_error_handling,
+)
+
+
+class TestInvalidTransformer(BaseEstimator, TransformerMixin):
+    """This class is meant for testing purposes only.
+
+    All molecules containing an atom with one of the given atomic numbers are
+    returned as invalid instance.
+
+    Attributes
+    ----------
+    atomic_number_set: Sequence[int] | None
+        Atomic numbers which upon occurrence in the molecule make it invalid.
+        ``None`` selects the default, sulfur (16).
+    """
+
+    # The class name starts with "Test": tell pytest not to collect it.
+    __test__ = False
+
+    atomic_number_set: Sequence[int] | None
+
+    def __init__(self, atomic_number_set: Sequence[int] | None = None) -> None:
+        # Store the parameter unmodified; sklearn's clone() rejects estimators
+        # whose constructor alters its arguments.
+        self.atomic_number_set = atomic_number_set
+
+    def _forbidden_atomic_numbers(self) -> set[int]:
+        """Return the configured atomic numbers as a set, defaulting to sulfur."""
+        if self.atomic_number_set is None:
+            return {16}
+        return set(self.atomic_number_set)
+
+    def _transform_mol(self, mol: Chem.Mol) -> Chem.Mol | InvalidInstance:
+        """Return ``mol`` unchanged, or an InvalidInstance if it has a forbidden element."""
+        unique_elements = {atom.GetAtomicNum() for atom in mol.GetAtoms()}
+        forbidden_elements = self._forbidden_atomic_numbers() & unique_elements
+        if forbidden_elements:
+            return InvalidInstance(str(self), f"Molecule contains {forbidden_elements}")
+        return mol
+
+    @rdkit_error_handling
+    def transform(self, X: list[Chem.Mol]) -> list[Chem.Mol | InvalidInstance]:
+        """Apply the element check to every molecule in ``X``."""
+        return [self._transform_mol(mol) for mol in X]
+
+    def fit(self, X, y=None, fit_params=None):
+        """Do nothing, as there is nothing to fit; return the transformer itself."""
+        return self
+
+    def fit_transform(self, X, y=None, **fit_params):
+        """Transform ``X``; no fitting is required."""
+        return self.transform(X)