Skip to content

Commit

Permalink
Merge pull request #42 from c-feldmann/handle_invalid_molecules
Browse files Browse the repository at this point in the history
Handle invalid molecules. Draft, examine pytest errors and review.
  • Loading branch information
EBjerrum authored Sep 17, 2024
2 parents 17ef6d1 + 9e72d67 commit e076a1c
Show file tree
Hide file tree
Showing 15 changed files with 401 additions and 12 deletions.
3 changes: 3 additions & 0 deletions .idea/.gitignore

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 6 additions & 0 deletions .idea/inspectionProfiles/Project_Default.xml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 6 additions & 0 deletions .idea/inspectionProfiles/profiles_settings.xml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

10 changes: 10 additions & 0 deletions .idea/misc.xml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

13 changes: 13 additions & 0 deletions .idea/scikit-mol.iml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 6 additions & 0 deletions .idea/vcs.xml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

146 changes: 146 additions & 0 deletions scikit_mol/_invalid.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
from abc import ABC
from typing import Any, Callable, NamedTuple, Sequence, TypeVar

import numpy as np
import numpy.typing as npt

_T = TypeVar("_T")
_U = TypeVar("_U")


class InvalidInstance(NamedTuple):
"""
The InvalidInstance represents objects which raised an error during a pipeline step.
"""
pipeline_step: str
error: str


class NumpyArrayWithInvalidInstances:
"""
The NumpyArrayWithInvalidInstances is
"""
is_valid_array: npt.NDArray[np.bool_]
invalid_list: list[InvalidInstance]
value_array: npt.NDArray[Any]

def __init__(self, array_list: list[npt.NDArray[Any] | InvalidInstance]):
self.is_valid_array = get_is_valid_array(array_list)
valid_vector_list = filter_by_list(array_list, self.is_valid_array)
self.value_array = np.vstack(valid_vector_list)
self.invalid_list = filter_by_list(array_list, ~self.is_valid_array)

def __len__(self):
return self.is_valid_array.shape[0]

def __getitem__(self, item: int) -> npt.NDArray[Any] | InvalidInstance:
n_invalids_prior = sum(~self.is_valid_array[:item - 1])
if self.is_valid_array[item]:
return self.value_array[item - n_invalids_prior]
return self.invalid_list[n_invalids_prior + 1]

def __setitem__(self, key: int, value: npt.NDArray[Any] | InvalidInstance) -> None:
n_invalids_prior = sum(~self.is_valid_array[:key - 1])
if isinstance(value, InvalidInstance):
if self.is_valid_array[key]:
self.value_array = np.delete(self.value_array, key - n_invalids_prior)
self.is_valid_array[key] = False
self.invalid_list.insert(n_invalids_prior + 1, value)
else:
self.invalid_list[n_invalids_prior + 1] = value
else:
if self.is_valid_array[key]:
self.value_array[key - n_invalids_prior] = value
else:
self.value_array = np.insert(self.value_array, key - n_invalids_prior, value)
del(self.invalid_list[n_invalids_prior + 1])
self.is_valid_array[key] = True

def array_filled_with(self, fill_value) -> npt.NDArray[Any]:
out = np.full((len(self.is_valid_array), self.value_array.shape[1]), fill_value)
out[self.is_valid_array] = self.value_array
return out


def batch_update_sequence(
old_list: list[npt.NDArray[Any] | InvalidInstance] | NumpyArrayWithInvalidInstances,
new_values: list[Any],
value_indices: npt.NDArray[np.int_],
):
old_list = list(old_list) # Make shallow copy of list to avoid inplace changes.
for new_value, idx in zip(new_values, value_indices, strict=True):
old_list[idx] = new_value
return old_list


def filter_by_list(item_list, is_valid_array: npt.NDArray[np.bool_]):
if isinstance(item_list, np.ndarray):
return item_list[is_valid_array]

item_list_new = []
for item, is_valid in zip(item_list, is_valid_array):
if is_valid:
item_list_new.append(item)
return item_list_new


# Callable[[Sequence[Any], Sequence[Any], dict[str, Any]], Sequence[Any]]
# ) -> Callable[[Sequence[Any], Sequence[Any], dict[str, Any]], npt.NDArray[Any]]
def rdkit_error_handling(func):
def wrapper(obj, *args, **kwargs):
x = args[0]
if isinstance(x, NumpyArrayWithInvalidInstances):
is_valid_array = x.is_valid_array
x_sub = x.value_array
else:
is_valid_array = get_is_valid_array(x)
x_sub = filter_by_list(x, is_valid_array)
if len(args) > 1:
y = args[1]
if y is not None:
y_sub = filter_by_list(y, is_valid_array)
else:
y_sub = None
x_new = func(obj, x_sub, y_sub, **kwargs)
else:
x_new = func(obj, x_sub, **kwargs)

if x_new is None: # fit may not return anything
return None
new_pos = np.where(is_valid_array)[0]
if isinstance(x_new, np.ndarray) and isinstance(x, NumpyArrayWithInvalidInstances):
if x_new.shape[0] != x.value_array.shape[0]:
raise AssertionError("Numer of rows is not as expected.")
x.value_array = x_new
return x
if isinstance(x, (list, NumpyArrayWithInvalidInstances)):
x_list = batch_update_sequence(x, x_new, new_pos)
else:
x_array = np.array(x)
x_array[is_valid_array] = x_new
x_list = list(x_array)
if isinstance(x_new, NumpyArrayWithInvalidInstances):
return NumpyArrayWithInvalidInstances(x_list)
return x_list
return wrapper


def filter_rows(
X: Sequence[_T], y: Sequence[_U]
) -> tuple[Sequence[_T], Sequence[_U]]:
is_valid_array = get_is_valid_array(X)
x_new = filter_by_list(X, is_valid_array)
y_new = filter_by_list(y, is_valid_array)
return x_new, y_new


def get_is_valid_array(item_list: Sequence[Any]) -> npt.NDArray[np.bool_]:
is_valid_list = []
for i, item in enumerate(item_list):
if not isinstance(item, InvalidInstance):
is_valid_list.append(True)
else:
is_valid_list.append(False)
return np.array(is_valid_list, dtype=bool)


14 changes: 11 additions & 3 deletions scikit_mol/conversions.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@

from scikit_mol.core import check_transform_input, feature_names_default_mol ,DEFAULT_MOL_COLUMN_NAME

from scikit_mol._invalid import InvalidInstance


class SmilesToMolTransformer(BaseEstimator, TransformerMixin):

Expand Down Expand Up @@ -62,9 +64,15 @@ def _transform(self, X):
if mol:
X_out.append(mol)
else:
raise ValueError(f'Issue with parsing SMILES {smiles}\nYou probably should use the scikit-mol.sanitizer.Sanitizer on your dataset first')

return np.array(X_out).reshape(-1,1)
mol = Chem.MolFromSmiles(smiles, sanitize=False)
if mol:
errors = Chem.DetectChemistryProblems(mol)
error_message = "\n".join(error.Message() for error in errors)
message = f"Invalid SMILES: {error_message}"
else:
message = f"Invalid SMILES: {smiles}"
X_out.append(InvalidInstance(str(self), message))
return X_out

@check_transform_input
def inverse_transform(self, X_mols_list, y=None): #TODO, maybe the inverse transform should be configurable e.g. isomericSmiles etc.?
Expand Down
17 changes: 10 additions & 7 deletions scikit_mol/fingerprints.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,11 @@

from sklearn.base import BaseEstimator, TransformerMixin
from scikit_mol.core import check_transform_input
from scikit_mol._invalid import NumpyArrayWithInvalidInstances, rdkit_error_handling

from abc import ABC, abstractmethod


_PATTERN_FINGERPRINT_TRANSFORMER = re.compile(r"^(?P<fingerprint_name>\w+)FingerprintTransformer$")

#%%
Expand Down Expand Up @@ -92,10 +94,10 @@ def fit(self, X, y=None):

@check_transform_input
def _transform(self, X):
arr = np.zeros((len(X), self.nBits), dtype=self._DTYPE_FINGERPRINT)
arr_list = []
for i, mol in enumerate(X):
arr[i,:] = self._transform_mol(mol)
return arr
arr_list.append(self._transform_mol(mol))
return NumpyArrayWithInvalidInstances(arr_list)

def _transform_sparse(self, X):
arr = np.zeros((len(X), self.nBits), dtype=self._DTYPE_FINGERPRINT)
Expand All @@ -104,6 +106,7 @@ def _transform_sparse(self, X):

return lil_matrix(arr)

@rdkit_error_handling
def transform(self, X, y=None):
"""Transform a list of RDKit molecule objects into a fingerprint array
Expand Down Expand Up @@ -134,10 +137,10 @@ def transform(self, X, y=None):
# TODO: create "transform_parallel" function in the core module,
# and use it here and in the descriptors transformer
#x_chunks = [np.array(x).reshape(-1, 1) for x in x_chunks]
arrays = pool.map(parallel_helper, [(self.__class__.__name__, parameters, x_chunk) for x_chunk in x_chunks])

arr = np.concatenate(arrays)
return arr
arrays = pool.map(parallel_helper, [(self.__class__.__name__, parameters, x_chunk) for x_chunk in x_chunks])
arr_list = []
arr_list.extend(arrays)
return arr_list


class MACCSKeysFingerprintTransformer(FpsTransformer):
Expand Down
2 changes: 2 additions & 0 deletions scikit_mol/utilities.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
#For a non-scikit-learn check smiles sanitizer class

import pandas as pd
from rdkit import Chem


class CheckSmilesSanitazion:
def __init__(self, return_mol=False):
self.return_mol = return_mol
Expand Down
95 changes: 95 additions & 0 deletions scikit_mol/wrapper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
"""Wrapper for sklearn estimators and pipelines to handle errors."""

from abc import ABC
from typing import Any

import numpy as np
from sklearn.base import BaseEstimator
from sklearn.pipeline import Pipeline
from sklearn.utils.metaestimators import available_if

from scikit_mol._invalid import rdkit_error_handling, InvalidInstance, NumpyArrayWithInvalidInstances


class AbstractWrapper(BaseEstimator, ABC):
"""
Abstract class for the wrapper of sklearn objects.
Attributes
----------
model: BaseEstimator | Pipeline
The wrapped model or pipeline.
"""
model: BaseEstimator | Pipeline

def __init__(self, replace_invalid: bool, replace_value: Any = np.nan):
"""Initialize the AbstractWrapper.
Parameters
----------
replace_invalid: bool
Whether to replace or remove errors
replace_value: Any, default=np.nan
If replace_invalid==True, insert this value on the erroneous instance.
"""
self.replace_invalid = replace_invalid
self.replace_value = replace_value

@rdkit_error_handling
def fit(self, X, y, **fit_params) -> Any:
return self.model.fit(X, y, **fit_params)

def has_predict(self) -> bool:
return hasattr(self.model, "predict")

def has_fit_predict(self) -> bool:
return hasattr(self.model, "fit_predict")


class WrappedTransformer(AbstractWrapper):
"""Wrapper for sklearn transformer objects."""

def __init__(self, model: BaseEstimator, replace_invalid: bool = False, replace_value=np.nan):
"""Initialize the WrappedTransformer.
Parameters
----------
model: BaseEstimator
Wrapped model to be protected against Errors.
replace_invalid: bool
Whether to replace or remove errors
replace_value: Any, default=np.nan
If replace_invalid==True, insert this value on the erroneous instance.
"""
super().__init__(replace_invalid=replace_invalid, replace_value=replace_value)
self.model = model

def has_transform(self) -> bool:
return hasattr(self.model, "transform")

def has_fit_transform(self) -> bool:
return hasattr(self.model, "fit_transform")

@available_if(has_transform)
@rdkit_error_handling
def transform(self, X):
return self.model.transform(X)

@rdkit_error_handling
def _fit_transform(self, X, y):
return self.model.fit_transform(X, y)

@available_if(has_fit_transform)
def fit_transform(self, X, y=None):
out = self._fit_transform(X,y)
if not self.replace_invalid:
return out

if isinstance(out, NumpyArrayWithInvalidInstances):
return out.array_filled_with(self.replace_value)

if isinstance(out, list):
return [self.replace_value if isinstance(v, InvalidInstance) else v for v in out]



5 changes: 3 additions & 2 deletions tests/fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,9 @@ def chiral_smiles_list(): #Need to be a certain size, so the fingerprints reacts
'N[C@@H](C)C(=O)Oc1ccccc1CCCCCCCCCCCCCCCCCCN[H]']]

@pytest.fixture
def invalid_smiles_list(smiles_list):
smiles_list.append('Invalid')
def invalid_smiles_list():
smiles_list = ['S-CC', 'Invalid']
smiles_list.extend(_SMILES_LIST)
return smiles_list

_MOLS_LIST = [Chem.MolFromSmiles(smiles) for smiles in _SMILES_LIST]
Expand Down
Loading

0 comments on commit e076a1c

Please sign in to comment.