Handle invalid molecules #53

Merged on Oct 13, 2024 (47 commits)
Commits
e8744d2
Add InvalidInstance and add empty file for wrappers
c-feldmann Sep 22, 2023
361284b
Nothing works but I want to save commit
c-feldmann Sep 22, 2023
e5c6a20
first working draft
c-feldmann Sep 22, 2023
db7cbe5
Merge branch 'EBjerrum:main' into main
c-feldmann Sep 23, 2023
a74eeaf
Merge remote-tracking branch 'origin/main' into add_predictor
c-feldmann Sep 23, 2023
902d549
Merge branch 'EBjerrum:main' into main
c-feldmann Sep 13, 2024
893e155
Merge remote-tracking branch 'origin/main' into add_predictor
c-feldmann Sep 13, 2024
2593ca2
Add docstrings
c-feldmann Sep 13, 2024
de5d0d8
refactor unittest
c-feldmann Sep 13, 2024
7ddbfe1
add Transformer which can make molecules invalid
c-feldmann Sep 13, 2024
17f19d4
Resolve merge issues with smiles_list
c-feldmann Sep 13, 2024
6c1d3e6
Add docstring
c-feldmann Sep 13, 2024
591a2ac
Add init
c-feldmann Sep 13, 2024
c713ece
Add Message encountering Errors
c-feldmann Sep 13, 2024
85c9745
Fix Message encountering Errors
c-feldmann Sep 13, 2024
ce55a52
Fix reference to test classes
c-feldmann Sep 13, 2024
9e72d67
Add __len__ to class
c-feldmann Sep 13, 2024
e076a1c
Merge pull request #42 from c-feldmann/handle_invalid_molecules
EBjerrum Sep 17, 2024
0c6cf7e
Simplifying datatypes. Conversions use invalidMol and MACCSkeys are r…
Sep 27, 2024
79f5dab
Added the simplified NanGuardWrapper. Some tests fails, especially ar…
Sep 27, 2024
7bbd8f1
Added support for pandas output. However, the set_output(), does not …
Sep 27, 2024
716c898
implemented the handle_errors flag on smilestomol transformer
Sep 27, 2024
0077ae6
Fixed an error in the tests of the sanitizer
Sep 27, 2024
3a195e6
cleanup of some accidentially added hidden files
Sep 27, 2024
cf98fd8
Added a basic test of an error_handling pipeline
Sep 27, 2024
bd3b262
formatting changes
Sep 27, 2024
bb5c506
Cleanup
Sep 27, 2024
91fff95
Added error handling to fingerprint classes. Also added a utility to …
Sep 27, 2024
50d9004
Updated standardizer and transformer for handling the errors. Still n…
Sep 27, 2024
555afaf
Cleaning up.- We are getting closer
Sep 27, 2024
efbd8b9
Fixed a bug in a test
Sep 27, 2024
4078411
updating smiles_to_mol test case
Oct 3, 2024
dd81ad1
Merge branch 'main' into handle_invalid_molecules
Oct 3, 2024
63bfabe
Updated test of smilestomol to check for the handle_errors capabiliti…
Oct 3, 2024
85ec6ca
Developed a test of the error_handling for fingerprint transformers, …
Oct 3, 2024
70a0598
Added test of the fingerprint classes for error handling
Oct 3, 2024
e1b2557
Changed name to safe_inference mode consistently and fixed the pytest…
Oct 9, 2024
35c33a2
Created test for descriptor transformer and fixed bugs in descriptor …
Oct 9, 2024
8b85989
Added dtype directly as properties on the objects, not on the class.
Oct 10, 2024
d074d2c
Removed deprecated test
Oct 10, 2024
973dea1
Some minor fixes and updates of messages to be more concise
Oct 10, 2024
48d4233
Fixed double parsing of SMILES
Oct 11, 2024
a6076d1
Fixed an issue with pandas output
Oct 11, 2024
08dc417
Created a test and updated name of the module for safeinference
Oct 13, 2024
a735b7c
Updated Notebook and links to notebook in readme (will first work whe…
Oct 13, 2024
f17fd43
Some cleanup
Oct 13, 2024
c4e4c69
Fixed spurious import
Oct 13, 2024
1 change: 1 addition & 0 deletions README.md
@@ -82,6 +82,7 @@ There are a collection of notebooks in the notebooks directory which demonstrate
- [Using skopt for hyperparameter tuning](https://github.com/EBjerrum/scikit-mol/tree/main/notebooks/08_external_library_skopt.ipynb)
- [Testing different fingerprints as part of the hyperparameter optimization](https://github.com/EBjerrum/scikit-mol/blob/main/notebooks/09_Combinatorial_Method_Usage_with_FingerPrint_Transformers.ipynb)
- [Using pandas output for easy feature importance analysis and combine pre-exisitng values with new computations](https://github.com/EBjerrum/scikit-mol/blob/main/notebooks/10_pipeline_pandas_output.ipynb)
- [Working with pipelines and estimators in safe inference mode for handling prediction on batches with invalid smiles or molecules](https://github.com/EBjerrum/scikit-mol/blob/main/notebooks/11_safe_inference.ipynb)

We also put a software note on ChemRxiv. [https://doi.org/10.26434/chemrxiv-2023-fzqwd](https://doi.org/10.26434/chemrxiv-2023-fzqwd)

1,023 changes: 1,023 additions & 0 deletions notebooks/11_safe_inference.ipynb

Large diffs are not rendered by default.

145 changes: 145 additions & 0 deletions notebooks/11_safe_inference.py
@@ -0,0 +1,145 @@
# ---
# jupyter:
# jupytext:
# formats: ipynb,py:percent
# text_representation:
# extension: .py
# format_name: percent
# format_version: '1.3'
# jupytext_version: 1.16.1
# kernelspec:
# display_name: vscode
# language: python
# name: python3
# ---

# %% [markdown]
# # Safe inference mode
#
# I think everyone who has worked with SMILES and RDKit sooner or later comes across a SMILES that doesn't parse. It can happen if the SMILES was produced with a different toolkit that is less strict with e.g. valence rules, or maybe a character was lost when copying it from an email. During curation of the dataset for training models, these SMILES need to be identified and eventually fixed or removed. But what happens when we are finished with our modelling? What kind of molecules and SMILES will a user send to the model in the future when it's in deployment? What kind of SMILES will a generative model create that we need to predict on? We don't know and we won't know. So it's crucial to be able to handle these situations. Scikit-learn models usually simply blow up on the entire batch being predicted if a single sample is invalid. This is why safe_inference_mode was introduced in Scikit-Mol. With its introduction all transformers got a safe inference mode, where they handle invalid input. How they handle it depends a bit on the transformer, so we will go through the usual steps and see how things have changed with the introduction of the safe inference mode.
#
# NOTE! In the following demonstration I switch on the safe inference mode on individual objects for demonstration purposes. I would not recommend doing that while building and training models; instead I would switch it on _after_ training and evaluation (more on that later). Otherwise there's a risk of training on the 2% of a dataset that didn't fail....
#
# First some imports and test SMILES and molecules.

# %%
from rdkit import Chem
from scikit_mol.conversions import SmilesToMolTransformer

# We have some deprecation warnings that we are addressing, but they would just distract from this demonstration
import warnings
warnings.filterwarnings("ignore", category=DeprecationWarning)

smiles = ["C1=CC=C(C=C1)F", "C1=CC=C(C=C1)O", "C1=CC=C(C=C1)N", "C1=CC=C(C=C1)Cl"]
smiles_with_invalid = smiles + ["N(C)(C)(C)C", "I'm not a SMILES"]

smi2mol = SmilesToMolTransformer(safe_inference_mode=True)

mols_with_invalid = smi2mol.transform(smiles_with_invalid)
mols_with_invalid

# %% [markdown]
# Without the safe inference mode, the transformation would simply fail, but now we get the expected array back with our RDKit molecules, and the last two entries are objects of the type InvalidMol (one for the SMILES with an invalid nitrogen valence, one for the string that is not a SMILES at all). InvalidMol is simply a placeholder that records which step failed the conversion and the error. InvalidMol evaluates to `False` in boolean contexts, so it is easy to filter away and handle in `if` statements and list comprehensions. For example:

# %%
[mol for mol in mols_with_invalid if mol]

# %% [markdown]
# or

# %%
mask = mols_with_invalid.astype(bool)
mols_with_invalid[mask]
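
The `InvalidMol` class itself lives in `scikit_mol.core` and is not shown in this diff. As a rough illustration of why the filtering above works, here is a minimal sketch of a falsy placeholder. The class name `InvalidMolSketch` and its two fields are assumptions made for this example, not the library's implementation.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class InvalidMolSketch:
    """Hypothetical stand-in for scikit-mol's InvalidMol (illustration only)."""

    pipeline_step: str  # which step rejected the input
    error: str  # human-readable reason

    def __bool__(self) -> bool:
        # Always falsy, so `if mol:` and list comprehensions skip it.
        return False


# Strings stand in for RDKit molecules here; any truthy object behaves the same.
results = ["mol_a", "mol_b", InvalidMolSketch("SmilesToMol", "Invalid SMILES: xyz")]

valid = [item for item in results if item]
failed = [item for item in results if not item]

print(valid)
print([f.error for f in failed])
```

Because the placeholder carries the failing step and the message, the rejected entries can still be logged or reported after the valid ones have been filtered out.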

# %% [markdown]
# Having a failsafe SmilesToMol conversion leads us to the next step, featurization. The transformers in safe inference mode now return a NumPy masked array instead of a regular NumPy array. They simply evaluate the incoming mols in a boolean context, so e.g. `None` and other Python objects that evaluate to `False` will also get masked (useful if you use a dataframe with an ROMol column produced with the PandasTools utility). Note that `np.nan` itself is truthy in Python, so a plain boolean check does not catch it.

# %%
from scikit_mol.fingerprints import MorganFingerprintTransformer

mfp = MorganFingerprintTransformer(radius=2, nBits=25, safe_inference_mode=True)
fps = mfp.transform(mols_with_invalid)
fps
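
To make the masked-array behaviour concrete, the following is a minimal NumPy-only sketch in which every row whose input evaluates to `False` is masked as a whole. The helper `featurize_masked` and its toy "fingerprint" are hypothetical and only mimic the idea; they are not the transformer's actual code.

```python
import numpy as np


def featurize_masked(items, n_features=4):
    """Toy featurizer: one row per item, masking rows for falsy items.

    Hypothetical helper for illustration; real fingerprints come from RDKit.
    """
    data = np.zeros((len(items), n_features), dtype=np.int8)
    mask = np.zeros((len(items), n_features), dtype=bool)
    for i, item in enumerate(items):
        if item:  # valid input: set one toy feature bit
            data[i, len(item) % n_features] = 1
        else:  # None or another falsy placeholder: mask the whole row
            mask[i, :] = True
    return np.ma.masked_array(data, mask=mask)


fps_sketch = featurize_masked(["CCO", "c1ccccc1", None])
row_is_valid = ~fps_sketch.mask.any(axis=1)

print(fps_sketch)
print(row_is_valid)
```

The per-row validity vector is what a downstream consumer needs in order to decide which rows may be passed on to a model.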


# %% [markdown]
# However, scikit-learn models currently accept masked arrays but do not respect the mask! So if you fed the array directly to a model for training, it would seemingly work, but the invalid samples would all have the fill_value, meaning you could get weird results. Instead we need the last piece of the puzzle, the SafeInferenceWrapper class.

# %%
from scikit_mol.safeinference import SafeInferenceWrapper
from sklearn.linear_model import LogisticRegression
import numpy as np

regressor = LogisticRegression()
wrapper = SafeInferenceWrapper(regressor, safe_inference_mode=True)
wrapper.fit(fps, [0,1,0,1,0,1])
wrapper.predict(fps)
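
The idea behind such a wrapper can be sketched in a few lines: predict only on rows that are unmasked and finite, and scatter the results back into an output filled with `nan`. This is a simplified illustration under stated assumptions. `safe_predict` and `MeanThresholdModel` are hypothetical names, and the real `SafeInferenceWrapper` may be implemented differently.

```python
import numpy as np


class MeanThresholdModel:
    """Tiny stand-in estimator so the sketch runs without scikit-learn."""

    def predict(self, X):
        return (np.asarray(X).mean(axis=1) > 0.5).astype(float)


def safe_predict(model, X):
    """Predict only on usable rows and return NaN for the others.

    Hypothetical helper; it mirrors the idea, not SafeInferenceWrapper's code.
    """
    mask = np.ma.getmaskarray(X)  # all-False mask for plain ndarrays
    values = np.ma.getdata(X).astype(float)
    # A row is usable if nothing in it is masked and every value is finite.
    usable = ~mask.any(axis=1) & np.isfinite(values).all(axis=1)
    y_out = np.full(len(values), np.nan)
    if usable.any():
        y_out[usable] = model.predict(values[usable])
    return y_out


X = np.ma.masked_array(
    [[1.0, 1.0], [0.0, 0.0], [9.0, 9.0], [np.inf, 1.0]],
    mask=[[False, False], [False, False], [True, True], [False, False]],
)
preds = safe_predict(MeanThresholdModel(), X)
print(preds)
```

The output keeps one entry per input row, so callers can align predictions with the original batch without tracking which rows were dropped.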


# %% [markdown]
#

# %% [markdown]
# Both fit and predict went fine, and the result shows `nan` for the invalid entries. However, please note that fitting in safe_inference_mode is not recommended in a training session; you are warned but not blocked, because maybe you know what you are doing and do it on purpose.
# The SafeInferenceWrapper handles rows that are masked in masked arrays, and also checks rows for nonfinite values and filters these away. Sometimes a descriptor may return an inf or nan even though the molecule itself is valid. The masking of nonfinite values can be switched off, e.g. if you are using a model that can handle missing data and only want to filter away invalid molecules.
#
# ## Setting safe_inference_mode post-training
# As I said before, I believe in catching and fixing errors during training, but what do we do when we need to switch on safe inference mode for all objects in a pipeline? There's of course a tool for that, so let's demo it:

# %%
from scikit_mol.safeinference import set_safe_inference_mode
from sklearn.pipeline import Pipeline

pipe = Pipeline([
    ("smi2mol", SmilesToMolTransformer()),
    ("mfp", MorganFingerprintTransformer(radius=2, nBits=25)),
    ("safe_regressor", SafeInferenceWrapper(LogisticRegression()))
])

pipe.fit(smiles, [1,0,1,0])

print("Without safe inference mode:")
try:
    pipe.predict(smiles_with_invalid)
except Exception as e:
    print("Prediction failed with exception: ", e)
print()

set_safe_inference_mode(pipe, True)

print("With safe inference mode:")
print(pipe.predict(smiles_with_invalid))
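
How a utility like this might work can be sketched without scikit-learn: walk the object, set the flag wherever the attribute exists, and recurse into nested steps. The classes `Step` and `ToyPipeline` and the function `set_mode_recursively` are hypothetical stand-ins for this example. The real `set_safe_inference_mode` may traverse scikit-learn objects differently.

```python
class Step:
    """Minimal object exposing a safe_inference_mode attribute."""

    def __init__(self, safe_inference_mode=False):
        self.safe_inference_mode = safe_inference_mode


class ToyPipeline:
    """Stand-in for a pipeline: holds (name, step) pairs in `steps`."""

    def __init__(self, steps):
        self.steps = steps


def set_mode_recursively(obj, value):
    """Set safe_inference_mode on obj and on anything nested in obj.steps.

    Hypothetical traversal for illustration only.
    """
    if hasattr(obj, "safe_inference_mode"):
        obj.safe_inference_mode = value
    for _, step in getattr(obj, "steps", []):
        set_mode_recursively(step, value)


inner = ToyPipeline([("fp", Step())])
outer = ToyPipeline([("smi2mol", Step()), ("features", inner), ("model", Step())])

set_mode_recursively(outer, True)
flags = [
    outer.steps[0][1].safe_inference_mode,
    inner.steps[0][1].safe_inference_mode,
    outer.steps[2][1].safe_inference_mode,
]
print(flags)
```

Setting the flag in one place after training avoids forgetting a single step, which would otherwise make the whole batch fail at that step.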

# %% [markdown]
# We see that the prediction fails without safe inference mode, and proceeds once it's conveniently set by the `set_safe_inference_mode` utility. The model is now ready to be saved and reused in a more failsafe manner :-)

# %% [markdown]
# ## Combining safe_inference_mode with pandas output
# One potential issue can arise when we combine safe_inference_mode with the pandas output mode of the transformers. It will work, but depending on the batch something surprising can happen due to the way that pandas converts masked NumPy arrays. Let me demonstrate the issue. First we transform a batch without any errors.

# %%
mfp.set_output(transform="pandas")

mols = smi2mol.transform(smiles)

fps = mfp.transform(mols)
fps

# %% [markdown]
# Then let's see what happens if we transform a batch with an invalid molecule:

# %%
fps = mfp.transform(mols_with_invalid)
fps

# %% [markdown]
# The second output is no longer integers, but floats. As most sklearn models cast input arrays to a floating-point dtype internally, this difference is likely benign, but that's not guaranteed! Thus, if you want to use pandas output for your production models, do check that the final outputs are the same for the valid rows, with and without a single invalid row. Alternatively, the dtype for the output of the transformer can be switched to float for consistency.

# %%
mfp_float = MorganFingerprintTransformer(radius=2, nBits=25, safe_inference_mode=True, dtype=np.float32)
mfp_float.set_output(transform="pandas")
fps = mfp_float.transform(mols)
fps
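
The recommended check, that valid rows give the same output with and without an invalid row in the batch, can be automated. The sketch below uses only the standard library. `valid_rows_agree` and `toy_predict` are hypothetical names, and `toy_predict` merely stands in for something like `pipe.predict`.

```python
import math


def valid_rows_agree(predict, valid_batch, invalid_item, tol=1e-9):
    """Return True if predictions for valid rows are unchanged by an invalid row.

    `predict` is any callable mapping a list of inputs to a list of floats,
    returning NaN for inputs it cannot handle (hypothetical interface).
    """
    clean = predict(valid_batch)
    mixed = predict(valid_batch + [invalid_item])
    # The extra row must come back as NaN, the others must match.
    if not math.isnan(mixed[-1]):
        return False
    return all(math.isclose(a, b, abs_tol=tol) for a, b in zip(clean, mixed[:-1]))


def toy_predict(batch):
    """Stand-in for pipe.predict: string length as 'prediction', NaN for None."""
    return [float(len(s)) if s else math.nan for s in batch]


print(valid_rows_agree(toy_predict, ["CCO", "CCN"], None))
```

Running such a check once before deployment, with both NumPy and pandas output enabled, is a cheap way to catch dtype-related surprises.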

# %% [markdown]
# I hope this new feature of Scikit-Mol will make it even easier to handle models, even when used in environments without SMILES or molecule validity guarantees.
1 change: 1 addition & 0 deletions notebooks/README.md
@@ -14,3 +14,4 @@ This is a collection of notebooks in the notebooks directory which demonstrates
- [Using skopt for hyperparameter tuning](https://github.com/EBjerrum/scikit-mol/tree/main/notebooks/08_external_library_skopt.ipynb)
- [Testing different fingerprints as part of the hyperparameter optimization](https://github.com/EBjerrum/scikit-mol/blob/main/notebooks/09_Combinatorial_Method_Usage_with_FingerPrint_Transformers.ipynb)
- [Using pandas output for easy feature importance analysis and combine pre-exisitng values with new computations](https://github.com/EBjerrum/scikit-mol/blob/main/notebooks/10_pipeline_pandas_output.ipynb)
- [Working with pipelines and estimators in safe inference mode](https://github.com/EBjerrum/scikit-mol/blob/main/notebooks/11_safe_inference.ipynb)
108 changes: 84 additions & 24 deletions scikit_mol/conversions.py
@@ -2,18 +2,44 @@
 import multiprocessing
 from typing import Union
 from rdkit import Chem
+from rdkit.rdBase import BlockLogs

 import numpy as np
 from sklearn.base import BaseEstimator, TransformerMixin

-from scikit_mol.core import check_transform_input, feature_names_default_mol ,DEFAULT_MOL_COLUMN_NAME
+from scikit_mol.core import (
+    check_transform_input,
+    feature_names_default_mol,
+    DEFAULT_MOL_COLUMN_NAME,
+    InvalidMol,
+)
+
+# from scikit_mol._invalid import InvalidMol

-class SmilesToMolTransformer(BaseEstimator, TransformerMixin):
-
-    def __init__(self, parallel: Union[bool, int] = False):
+class SmilesToMolTransformer(BaseEstimator, TransformerMixin):
+    """
+    Transformer for converting SMILES strings to RDKit mol objects.
+
+    This transformer can be included in pipelines during development and training,
+    but the safe inference mode should only be enabled when deploying models for
+    inference in production environments.
+
+    Parameters:
+    -----------
+    parallel : Union[bool, int], default=False
+        If True or int > 1, enables parallel processing.
+    safe_inference_mode : bool, default=False
+        If True, enables safeguards for handling invalid data during inference.
+        This should only be set to True when deploying models to production.
+    """
+
+    def __init__(
+        self, parallel: Union[bool, int] = False, safe_inference_mode: bool = False
+    ):
         self.parallel = parallel
-        self.start_method = None #TODO implement handling of start_method
+        self.start_method = None  # TODO implement handling of start_method
+        self.safe_inference_mode = safe_inference_mode

     @feature_names_default_mol
     def get_feature_names_out(self, input_features=None):
@@ -39,39 +65,73 @@ def transform(self, X_smiles_list, y=None):
         Raises
         ------
         ValueError
-            Raises ValueError if a SMILES string is unparsable by RDKit
+            Raises ValueError if a SMILES string is unparsable by RDKit and safe_inference_mode is False
         """

-
         if not self.parallel:
             return self._transform(X_smiles_list)
         elif self.parallel:
-            n_processes = self.parallel if self.parallel > 1 else None # Pool(processes=None) autodetects
-            n_chunks = n_processes*2 if n_processes is not None else multiprocessing.cpu_count()*2 #TODO, tune the number of chunks per child process
+            n_processes = (
+                self.parallel if self.parallel > 1 else None
+            )  # Pool(processes=None) autodetects
+            n_chunks = (
+                n_processes * 2
+                if n_processes is not None
+                else multiprocessing.cpu_count() * 2
+            )  # TODO, tune the number of chunks per child process
             with get_context(self.start_method).Pool(processes=n_processes) as pool:
-                x_chunks = np.array_split(X_smiles_list, n_chunks)
-                arrays = pool.map(self._transform, x_chunks) #is the helper function a safer way of handling the picklind and child process communication
-                arr = np.concatenate(arrays)
-                return arr
+                x_chunks = np.array_split(X_smiles_list, n_chunks)
+                arrays = pool.map(
+                    self._transform, x_chunks
+                )  # is the helper function a safer way of handling the picklind and child process communication
+                arr = np.concatenate(arrays)
+                return arr

     @check_transform_input
     def _transform(self, X):
         X_out = []
-        for smiles in X:
-            mol = Chem.MolFromSmiles(smiles)
-            if mol:
-                X_out.append(mol)
-            else:
-                raise ValueError(f'Issue with parsing SMILES {smiles}\nYou probably should use the scikit-mol.sanitizer.Sanitizer on your dataset first')
-
-        return np.array(X_out).reshape(-1,1)
+        with BlockLogs():
+            for smiles in X:
+                mol = Chem.MolFromSmiles(smiles, sanitize=False)
+                if mol:
+                    errors = Chem.DetectChemistryProblems(mol)
+                    if errors:
+                        error_message = "\n".join(error.Message() for error in errors)
+                        message = f"Invalid Molecule: {error_message}"
+                        X_out.append(InvalidMol(str(self), message))
+                    else:
+                        Chem.SanitizeMol(mol)
+                        X_out.append(mol)
+                else:
+                    message = f"Invalid SMILES: {smiles}"
+                    X_out.append(InvalidMol(str(self), message))
+        if not self.safe_inference_mode and not all(X_out):
+            fails = [x for x in X_out if not x]
+            raise ValueError(
+                f"Invalid input found: {fails}."
+            )  # TODO with this approach we get all errors, but we do process ALL the smiles first which could be slow
+        return np.array(X_out).reshape(-1, 1)

     @check_transform_input
-    def inverse_transform(self, X_mols_list, y=None): #TODO, maybe the inverse transform should be configurable e.g. isomericSmiles etc.?
+    def inverse_transform(self, X_mols_list, y=None):
         X_out = []

         for mol in X_mols_list:
-            smiles = Chem.MolToSmiles(mol)
-            X_out.append(smiles)
+            if isinstance(mol, Chem.Mol):
+                try:
+                    smiles = Chem.MolToSmiles(mol)
+                    X_out.append(smiles)
+                except Exception as e:
+                    X_out.append(
+                        InvalidMol(
+                            str(self), f"Error converting Mol to SMILES: {str(e)}"
+                        )
+                    )
+            else:
+                X_out.append(InvalidMol(str(self), f"Not a Mol: {mol}"))
+
+        if not self.safe_inference_mode and not all(isinstance(x, str) for x in X_out):
+            fails = [x for x in X_out if not isinstance(x, str)]
+            raise ValueError(f"Invalid Mols found: {fails}.")

-        return np.array(X_out).reshape(-1,1)
+        return np.array(X_out).reshape(-1, 1)
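
The control flow of the new `_transform` (convert everything first, collect placeholders for failures, and only then decide whether to raise) can be illustrated without RDKit. In this sketch `float()` stands in for SMILES parsing, and `Failed` and `convert_all` are hypothetical names chosen for the example.

```python
class Failed:
    """Falsy placeholder recording why an input was rejected (hypothetical)."""

    def __init__(self, message):
        self.message = message

    def __bool__(self):
        return False

    def __repr__(self):
        return f"Failed({self.message!r})"


def convert_all(items, safe_inference_mode=False):
    """Convert every item first, then decide whether failures are fatal."""
    out = []
    for item in items:
        try:
            out.append(float(item))  # float() stands in for SMILES parsing
        except ValueError:
            out.append(Failed(f"Invalid input: {item}"))
    # isinstance is used here because a valid result such as 0.0 is also falsy.
    fails = [x for x in out if isinstance(x, Failed)]
    if fails and not safe_inference_mode:
        raise ValueError(f"Invalid input found: {fails}.")
    return out


print(convert_all(["1.5", "oops", "2"], safe_inference_mode=True))
try:
    convert_all(["1.5", "oops"])
except ValueError as exc:
    print(exc)
```

The trade-off matches the TODO in the diff: the error message lists every failing input at once, at the cost of processing the whole batch before raising.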