diff --git a/README.md b/README.md index 927e41c..453129a 100644 --- a/README.md +++ b/README.md @@ -82,6 +82,7 @@ There are a collection of notebooks in the notebooks directory which demonstrate - [Using skopt for hyperparameter tuning](https://github.com/EBjerrum/scikit-mol/tree/main/notebooks/08_external_library_skopt.ipynb) - [Testing different fingerprints as part of the hyperparameter optimization](https://github.com/EBjerrum/scikit-mol/blob/main/notebooks/09_Combinatorial_Method_Usage_with_FingerPrint_Transformers.ipynb) - [Using pandas output for easy feature importance analysis and combine pre-exisitng values with new computations](https://github.com/EBjerrum/scikit-mol/blob/main/notebooks/10_pipeline_pandas_output.ipynb) +- [Working with pipelines and estimators in safe inference mode for handling prediction on batches with invalid SMILES or molecules](https://github.com/EBjerrum/scikit-mol/blob/main/notebooks/11_safe_inference.ipynb) We also put a software note on ChemRxiv. [https://doi.org/10.26434/chemrxiv-2023-fzqwd](https://doi.org/10.26434/chemrxiv-2023-fzqwd) diff --git a/notebooks/11_safe_inference.ipynb b/notebooks/11_safe_inference.ipynb new file mode 100644 index 0000000..6ee786e --- /dev/null +++ b/notebooks/11_safe_inference.ipynb @@ -0,0 +1,1023 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Safe inference mode\n", + "\n", + "I think everyone who has worked with SMILES and RDKit sooner or later comes across a SMILES that doesn't parse. It can happen if the SMILES was produced with a different toolkit that is less strict with e.g. valence rules, or maybe a character went missing when copying from an email. During curation of the dataset for training models, these SMILES need to be identified and eventually fixed or removed. But what happens when we are finished with our modelling? What kind of molecules and SMILES will a user send to the model once it's in deployment? What kind of SMILES will a generative model create that we need to predict? We don't know and we won't know. So it's crucial to be able to handle these situations. Scikit-Learn models usually simply fail for the entire batch being predicted. This is why safe_inference_mode was introduced in Scikit-Mol: all transformers got a safe inference mode in which they handle invalid input gracefully. How they handle it depends a bit on the transformer, so we will go through the usual steps and see how things have changed with the introduction of the safe inference mode.\n", + "\n", + "NOTE! In the following demonstration I switch on the safe inference mode individually for demonstration purposes. I would not recommend doing that while building and training models; instead I would switch it on _after_ training and evaluation (more on that later). Otherwise there's a risk of silently training on only the 2% of a dataset that didn't fail....\n", + "\n", + "First some imports and test SMILES and molecules."
+ ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([[<rdkit.Chem.rdchem.Mol object at 0x...>],\n", + " [<rdkit.Chem.rdchem.Mol object at 0x...>],\n", + " [<rdkit.Chem.rdchem.Mol object at 0x...>],\n", + " [<rdkit.Chem.rdchem.Mol object at 0x...>],\n", + " [InvalidMol('SmilesToMolTransformer(safe_inference_mode=True)', error='Invalid Molecule: Explicit valence for atom # 0 N, 4, is greater than permitted')],\n", + " [InvalidMol('SmilesToMolTransformer(safe_inference_mode=True)', error='Invalid SMILES: I'm not a SMILES')]],\n", + " dtype=object)" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from rdkit import Chem\n", + "from scikit_mol.conversions import SmilesToMolTransformer\n", + "\n", + "#We have some deprecation warnings, we are addressing them, but they just distract from this demonstration\n", + "import warnings\n", + "warnings.filterwarnings(\"ignore\", category=DeprecationWarning) \n", + "\n", + "smiles = [\"C1=CC=C(C=C1)F\", \"C1=CC=C(C=C1)O\", \"C1=CC=C(C=C1)N\", \"C1=CC=C(C=C1)Cl\"]\n", + "smiles_with_invalid = smiles + [\"N(C)(C)(C)C\", \"I'm not a SMILES\"]\n", + "\n", + "smi2mol = SmilesToMolTransformer(safe_inference_mode=True)\n", + "\n", + "mols_with_invalid = smi2mol.transform(smiles_with_invalid)\n", + "mols_with_invalid" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Without the safe inference mode, the transformation would simply fail, but now we get the expected array back with our RDKit molecules, and the invalid entries are objects of the type InvalidMol. InvalidMol is simply a placeholder that records which step failed the conversion and the error message. InvalidMol evaluates to `False` in boolean contexts, so it's easy to filter away and handle in `if` statements and list comprehensions. For example:" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[array([<rdkit.Chem.rdchem.Mol object at 0x...>], dtype=object),\n", + " array([<rdkit.Chem.rdchem.Mol object at 0x...>], dtype=object),\n", + " array([<rdkit.Chem.rdchem.Mol object at 0x...>], dtype=object),\n", + " array([<rdkit.Chem.rdchem.Mol object at 0x...>], dtype=object)]" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "[mol for mol in mols_with_invalid if mol]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "or" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([<rdkit.Chem.rdchem.Mol object at 0x...>,\n", + " <rdkit.Chem.rdchem.Mol object at 0x...>,\n", + " <rdkit.Chem.rdchem.Mol object at 0x...>,\n", + " <rdkit.Chem.rdchem.Mol object at 0x...>], dtype=object)" + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "mask = mols_with_invalid.astype(bool)\n", + "mols_with_invalid[mask]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Having a failsafe SmilesToMol conversion leads us to the next step, featurization. The transformers in safe inference mode now return a NumPy masked array instead of a regular NumPy array. They simply evaluate the incoming mols in a boolean context, so e.g. `None`, `np.nan` and other Python objects that evaluate to False will also get masked (e.g.
if you use a dataframe with an ROMol column produced with the PandasTools utility)" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[17:02:50] DEPRECATION WARNING: please use MorganGenerator\n", + "[17:02:50] DEPRECATION WARNING: please use MorganGenerator\n", + "[17:02:50] DEPRECATION WARNING: please use MorganGenerator\n", + "[17:02:50] DEPRECATION WARNING: please use MorganGenerator\n" + ] + }, + { + "data": { + "text/plain": [ + "masked_array(\n", + " data=[[0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 0, 1, 1, 1,\n", + " 0, 1, 1, 0],\n", + " [0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 1, 1, 1, 0, 0, 1, 0, 0, 1, 1,\n", + " 0, 0, 1, 0],\n", + " [0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 1, 1, 0, 1, 0, 1, 1, 1,\n", + " 0, 0, 0, 0],\n", + " [1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 1, 1,\n", + " 0, 1, 0, 1],\n", + " [--, --, --, --, --, --, --, --, --, --, --, --, --, --, --, --,\n", + " --, --, --, --, --, --, --, --, --],\n", + " [--, --, --, --, --, --, --, --, --, --, --, --, --, --, --, --,\n", + " --, --, --, --, --, --, --, --, --]],\n", + " mask=[[False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False],\n", + " [False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False],\n", + " [False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False],\n", + " [False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False],\n", + " [ True, True, True, True, True, True, True, True, True,\n", + " True, True, True, True, True, True, True, True, True,\n", + " True, True, True, True, True, True, True],\n", + " [ True, True, True, True, True, True, True, True, True,\n", + " True, True, True, True, True, True, True, True, True,\n", + " True, True, True, True, True, True, True]],\n", + " fill_value=999999,\n", + " dtype=int8)" + ] + }, + "execution_count": 15, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from scikit_mol.fingerprints import MorganFingerprintTransformer\n", + "\n", + "mfp = MorganFingerprintTransformer(radius=2, nBits=25, safe_inference_mode=True)\n", + "fps = mfp.transform(mols_with_invalid)\n", + "fps\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "However, scikit-learn models currently accept masked arrays, but they do not respect the mask! So if you fed one directly to a model for training, it would seemingly work, but the invalid samples would all carry the fill_value, meaning you could get weird results. Instead we need the last piece of the puzzle, the SafeInferenceWrapper class." + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/esben/git/scikit-mol/scikit_mol/safeinference.py:49: UserWarning: SafeInferenceWrapper is in safe_inference_mode during use of fit and invalid data detected.
This mode is intended for safe inference in production, not for training and evaluation.\n", + " warnings.warn(\n" + ] + }, + { + "data": { + "text/plain": [ + "array([ 0., 1., 0., 1., nan, nan])" + ] + }, + "execution_count": 16, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from scikit_mol.safeinference import SafeInferenceWrapper\n", + "from sklearn.linear_model import LogisticRegression\n", + "import numpy as np\n", + "\n", + "regressor = LogisticRegression()\n", + "wrapper = SafeInferenceWrapper(regressor, safe_inference_mode=True)\n", + "wrapper.fit(fps, [0,1,0,1,0,1])\n", + "wrapper.predict(fps)\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Both fit and predict went fine, and the result shows `nan` for the invalid entries. However, please note that fitting in safe_inference_mode is not recommended in a training session, but you are warned rather than blocked, because maybe you know what you're doing and do it on purpose.\n", + "The SafeInferenceWrapper not only handles rows that are masked in masked arrays, but also checks rows for nonfinite values and filters these away. Sometimes a descriptor may return an inf or nan, even though the molecule itself is valid. The masking of nonfinite values can be switched off, e.g. if you are using a model that can handle missing data and only want to filter away invalid molecules.\n", + "\n", + "## Setting safe_inference_mode post-training\n", + "As I said before, I believe in catching errors and fixing those during training, but what do we do when we need to switch on safe inference mode for all objects in a pipeline? There's of course a tool for that, so let's demo it:" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Without safe inference mode:\n", + "Prediction failed with exception: Invalid input found: [InvalidMol('SmilesToMolTransformer()', error='Invalid Molecule: Explicit valence for atom # 0 N, 4, is greater than permitted'), InvalidMol('SmilesToMolTransformer()', error='Invalid SMILES: I'm not a SMILES')].\n", + "\n", + "With safe inference mode:\n", + "[ 1. 0. 1. 0.
nan nan]\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[17:02:50] DEPRECATION WARNING: please use MorganGenerator\n", + "[17:02:50] DEPRECATION WARNING: please use MorganGenerator\n", + "[17:02:50] DEPRECATION WARNING: please use MorganGenerator\n", + "[17:02:50] DEPRECATION WARNING: please use MorganGenerator\n", + "[17:02:50] DEPRECATION WARNING: please use MorganGenerator\n", + "[17:02:50] DEPRECATION WARNING: please use MorganGenerator\n", + "[17:02:50] DEPRECATION WARNING: please use MorganGenerator\n", + "[17:02:50] DEPRECATION WARNING: please use MorganGenerator\n" + ] + } + ], + "source": [ + "from scikit_mol.safeinference import set_safe_inference_mode\n", + "from sklearn.pipeline import Pipeline\n", + "\n", + "pipe = Pipeline([\n", + " (\"smi2mol\", SmilesToMolTransformer()),\n", + " (\"mfp\", MorganFingerprintTransformer(radius=2, nBits=25)),\n", + " (\"safe_regressor\", SafeInferenceWrapper(LogisticRegression()))\n", + "])\n", + "\n", + "pipe.fit(smiles, [1,0,1,0])\n", + "\n", + "print(\"Without safe inference mode:\")\n", + "try:\n", + " pipe.predict(smiles_with_invalid)\n", + "except Exception as e:\n", + " print(\"Prediction failed with exception: \", e)\n", + "print()\n", + "\n", + "set_safe_inference_mode(pipe, True)\n", + "\n", + "print(\"With safe inference mode:\")\n", + "print(pipe.predict(smiles_with_invalid))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We see that the prediction fail without safe inference mode, and proceeds when it's conveniently set by the `set_safe_inference_mode` utility. The model is now ready for save and reuse in a more failsafe manner :-)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Combining safe_inference_mode with pandas output\n", + "One potential issue can happen when we combine the safe_inference_mode with Pandas output mode of the transformers. It will work, but depending on the batch something surprising can happen due to the way that Pandas converts masked Numpy arrays. Let me demonstrate the issue, first we predict a batch without any errors." + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[17:02:50] DEPRECATION WARNING: please use MorganGenerator\n", + "[17:02:50] DEPRECATION WARNING: please use MorganGenerator\n", + "[17:02:50] DEPRECATION WARNING: please use MorganGenerator\n", + "[17:02:50] DEPRECATION WARNING: please use MorganGenerator\n" + ] + }, + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
fp_morgan_1fp_morgan_2fp_morgan_3fp_morgan_4fp_morgan_5fp_morgan_6fp_morgan_7fp_morgan_8fp_morgan_9fp_morgan_10...fp_morgan_16fp_morgan_17fp_morgan_18fp_morgan_19fp_morgan_20fp_morgan_21fp_morgan_22fp_morgan_23fp_morgan_24fp_morgan_25
00000000011...0101110110
10000000111...0100110010
20000000011...0101110000
31000000011...0100110101
\n", + "

4 rows × 25 columns

\n", + "
" + ], + "text/plain": [ + " fp_morgan_1 fp_morgan_2 fp_morgan_3 fp_morgan_4 fp_morgan_5 \\\n", + "0 0 0 0 0 0 \n", + "1 0 0 0 0 0 \n", + "2 0 0 0 0 0 \n", + "3 1 0 0 0 0 \n", + "\n", + " fp_morgan_6 fp_morgan_7 fp_morgan_8 fp_morgan_9 fp_morgan_10 ... \\\n", + "0 0 0 0 1 1 ... \n", + "1 0 0 1 1 1 ... \n", + "2 0 0 0 1 1 ... \n", + "3 0 0 0 1 1 ... \n", + "\n", + " fp_morgan_16 fp_morgan_17 fp_morgan_18 fp_morgan_19 fp_morgan_20 \\\n", + "0 0 1 0 1 1 \n", + "1 0 1 0 0 1 \n", + "2 0 1 0 1 1 \n", + "3 0 1 0 0 1 \n", + "\n", + " fp_morgan_21 fp_morgan_22 fp_morgan_23 fp_morgan_24 fp_morgan_25 \n", + "0 1 0 1 1 0 \n", + "1 1 0 0 1 0 \n", + "2 1 0 0 0 0 \n", + "3 1 0 1 0 1 \n", + "\n", + "[4 rows x 25 columns]" + ] + }, + "execution_count": 18, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "mfp.set_output(transform=\"pandas\")\n", + "\n", + "mols = smi2mol.transform(smiles)\n", + "\n", + "fps = mfp.transform(mols)\n", + "fps" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Then lets see if we transform a batch with an invalid molecule:" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[17:02:50] DEPRECATION WARNING: please use MorganGenerator\n", + "[17:02:50] DEPRECATION WARNING: please use MorganGenerator\n", + "[17:02:50] DEPRECATION WARNING: please use MorganGenerator\n", + "[17:02:50] DEPRECATION WARNING: please use MorganGenerator\n" + ] + }, + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
fp_morgan_1fp_morgan_2fp_morgan_3fp_morgan_4fp_morgan_5fp_morgan_6fp_morgan_7fp_morgan_8fp_morgan_9fp_morgan_10...fp_morgan_16fp_morgan_17fp_morgan_18fp_morgan_19fp_morgan_20fp_morgan_21fp_morgan_22fp_morgan_23fp_morgan_24fp_morgan_25
00.00.00.00.00.00.00.00.01.01.0...0.01.00.01.01.01.00.01.01.00.0
10.00.00.00.00.00.00.01.01.01.0...0.01.00.00.01.01.00.00.01.00.0
20.00.00.00.00.00.00.00.01.01.0...0.01.00.01.01.01.00.00.00.00.0
31.00.00.00.00.00.00.00.01.01.0...0.01.00.00.01.01.00.01.00.01.0
4NaNNaNNaNNaNNaNNaNNaNNaNNaNNaN...NaNNaNNaNNaNNaNNaNNaNNaNNaNNaN
5NaNNaNNaNNaNNaNNaNNaNNaNNaNNaN...NaNNaNNaNNaNNaNNaNNaNNaNNaNNaN
\n", + "

6 rows × 25 columns

\n", + "
" + ], + "text/plain": [ + " fp_morgan_1 fp_morgan_2 fp_morgan_3 fp_morgan_4 fp_morgan_5 \\\n", + "0 0.0 0.0 0.0 0.0 0.0 \n", + "1 0.0 0.0 0.0 0.0 0.0 \n", + "2 0.0 0.0 0.0 0.0 0.0 \n", + "3 1.0 0.0 0.0 0.0 0.0 \n", + "4 NaN NaN NaN NaN NaN \n", + "5 NaN NaN NaN NaN NaN \n", + "\n", + " fp_morgan_6 fp_morgan_7 fp_morgan_8 fp_morgan_9 fp_morgan_10 ... \\\n", + "0 0.0 0.0 0.0 1.0 1.0 ... \n", + "1 0.0 0.0 1.0 1.0 1.0 ... \n", + "2 0.0 0.0 0.0 1.0 1.0 ... \n", + "3 0.0 0.0 0.0 1.0 1.0 ... \n", + "4 NaN NaN NaN NaN NaN ... \n", + "5 NaN NaN NaN NaN NaN ... \n", + "\n", + " fp_morgan_16 fp_morgan_17 fp_morgan_18 fp_morgan_19 fp_morgan_20 \\\n", + "0 0.0 1.0 0.0 1.0 1.0 \n", + "1 0.0 1.0 0.0 0.0 1.0 \n", + "2 0.0 1.0 0.0 1.0 1.0 \n", + "3 0.0 1.0 0.0 0.0 1.0 \n", + "4 NaN NaN NaN NaN NaN \n", + "5 NaN NaN NaN NaN NaN \n", + "\n", + " fp_morgan_21 fp_morgan_22 fp_morgan_23 fp_morgan_24 fp_morgan_25 \n", + "0 1.0 0.0 1.0 1.0 0.0 \n", + "1 1.0 0.0 0.0 1.0 0.0 \n", + "2 1.0 0.0 0.0 0.0 0.0 \n", + "3 1.0 0.0 1.0 0.0 1.0 \n", + "4 NaN NaN NaN NaN NaN \n", + "5 NaN NaN NaN NaN NaN \n", + "\n", + "[6 rows x 25 columns]" + ] + }, + "execution_count": 19, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "fps = mfp.transform(mols_with_invalid)\n", + "fps" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The second output is no longer integers, but floats. As most sklearn models cast input arrays to float32 internally, this difference is likely benign, but that's not guaranteed! Thus if you want to use pandas output for your production models, do check that the final outputs are the same for the valid rows, with and without a single invalid row. Alternatively the dtype for the output of the transformer can be switched to float for consistency." + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[17:02:50] DEPRECATION WARNING: please use MorganGenerator\n", + "[17:02:50] DEPRECATION WARNING: please use MorganGenerator\n", + "[17:02:50] DEPRECATION WARNING: please use MorganGenerator\n", + "[17:02:50] DEPRECATION WARNING: please use MorganGenerator\n" + ] + }, + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
fp_morgan_1fp_morgan_2fp_morgan_3fp_morgan_4fp_morgan_5fp_morgan_6fp_morgan_7fp_morgan_8fp_morgan_9fp_morgan_10...fp_morgan_16fp_morgan_17fp_morgan_18fp_morgan_19fp_morgan_20fp_morgan_21fp_morgan_22fp_morgan_23fp_morgan_24fp_morgan_25
00.00.00.00.00.00.00.00.01.01.0...0.01.00.01.01.01.00.01.01.00.0
10.00.00.00.00.00.00.01.01.01.0...0.01.00.00.01.01.00.00.01.00.0
20.00.00.00.00.00.00.00.01.01.0...0.01.00.01.01.01.00.00.00.00.0
31.00.00.00.00.00.00.00.01.01.0...0.01.00.00.01.01.00.01.00.01.0
\n", + "

4 rows × 25 columns

\n", + "
" + ], + "text/plain": [ + " fp_morgan_1 fp_morgan_2 fp_morgan_3 fp_morgan_4 fp_morgan_5 \\\n", + "0 0.0 0.0 0.0 0.0 0.0 \n", + "1 0.0 0.0 0.0 0.0 0.0 \n", + "2 0.0 0.0 0.0 0.0 0.0 \n", + "3 1.0 0.0 0.0 0.0 0.0 \n", + "\n", + " fp_morgan_6 fp_morgan_7 fp_morgan_8 fp_morgan_9 fp_morgan_10 ... \\\n", + "0 0.0 0.0 0.0 1.0 1.0 ... \n", + "1 0.0 0.0 1.0 1.0 1.0 ... \n", + "2 0.0 0.0 0.0 1.0 1.0 ... \n", + "3 0.0 0.0 0.0 1.0 1.0 ... \n", + "\n", + " fp_morgan_16 fp_morgan_17 fp_morgan_18 fp_morgan_19 fp_morgan_20 \\\n", + "0 0.0 1.0 0.0 1.0 1.0 \n", + "1 0.0 1.0 0.0 0.0 1.0 \n", + "2 0.0 1.0 0.0 1.0 1.0 \n", + "3 0.0 1.0 0.0 0.0 1.0 \n", + "\n", + " fp_morgan_21 fp_morgan_22 fp_morgan_23 fp_morgan_24 fp_morgan_25 \n", + "0 1.0 0.0 1.0 1.0 0.0 \n", + "1 1.0 0.0 0.0 1.0 0.0 \n", + "2 1.0 0.0 0.0 0.0 0.0 \n", + "3 1.0 0.0 1.0 0.0 1.0 \n", + "\n", + "[4 rows x 25 columns]" + ] + }, + "execution_count": 20, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "mfp_float = MorganFingerprintTransformer(radius=2, nBits=25, safe_inference_mode=True, dtype=np.float32)\n", + "mfp_float.set_output(transform=\"pandas\")\n", + "fps = mfp_float.transform(mols)\n", + "fps" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "I hope this new feature of Scikit-Mol will make it even easier to handle models, even when used in environments without SMILES or molecule validity guarantees." + ] + } + ], + "metadata": { + "jupytext": { + "formats": "ipynb,py:percent" + }, + "kernelspec": { + "display_name": "vscode", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/notebooks/11_safe_inference.py b/notebooks/11_safe_inference.py new file mode 100644 index 0000000..83d4d99 --- /dev/null +++ b/notebooks/11_safe_inference.py @@ -0,0 +1,145 @@ +# --- +# jupyter: +# jupytext: +# formats: ipynb,py:percent +# text_representation: +# extension: .py +# format_name: percent +# format_version: '1.3' +# jupytext_version: 1.16.1 +# kernelspec: +# display_name: vscode +# language: python +# name: python3 +# --- + +# %% [markdown] +# # Safe inference mode +# +# I think everyone which have worked with SMILES and RDKit sooner or later come across a SMILES that doesn't parse. It can happen if the SMILES was produced with a different toolkit that are less strict with e.g. valence rules, or maybe a characher was missing in the copying from the email. During curation of the dataset for training models, these SMILES need to be identfied and eventually fixed or removed. But what happens when we are finished with our modelling? What kind of molecules and SMILES will a user of the model send for the model in the future when it's in deployment. What kind of SMILES will a generative model create that we need to predict? We don't know and we won't know. So it's kind of crucial to be able to handle these situations. Scikit-Learn models usually simply explodes the entire batch that are being predicted. This is where safe_inference_mode was introduced in Scikit-Mol. With the introduction all transformers got a safe inference mode, where they handle invalid input. 
How they handle it depends a bit on the transformer, so we will go through the usual steps and see how things have changed with the introduction of the safe inference mode. +# +# NOTE! In the following demonstration I switch on the safe inference mode individually for demonstration purposes. I would not recommend doing that while building and training models; instead I would switch it on _after_ training and evaluation (more on that later). Otherwise there's a risk of silently training on only the 2% of a dataset that didn't fail.... +# +# First some imports and test SMILES and molecules. + +# %% +from rdkit import Chem +from scikit_mol.conversions import SmilesToMolTransformer + +#We have some deprecation warnings, we are addressing them, but they just distract from this demonstration +import warnings +warnings.filterwarnings("ignore", category=DeprecationWarning) + +smiles = ["C1=CC=C(C=C1)F", "C1=CC=C(C=C1)O", "C1=CC=C(C=C1)N", "C1=CC=C(C=C1)Cl"] +smiles_with_invalid = smiles + ["N(C)(C)(C)C", "I'm not a SMILES"] + +smi2mol = SmilesToMolTransformer(safe_inference_mode=True) + +mols_with_invalid = smi2mol.transform(smiles_with_invalid) +mols_with_invalid + +# %% [markdown] +# Without the safe inference mode, the transformation would simply fail, but now we get the expected array back with our RDKit molecules, and the invalid entries are objects of the type InvalidMol. InvalidMol is simply a placeholder that records which step failed the conversion and the error message. InvalidMol evaluates to `False` in boolean contexts, so it's easy to filter away and handle in `if` statements and list comprehensions. For example: + +# %% +[mol for mol in mols_with_invalid if mol] + +# %% [markdown] +# or + +# %% +mask = mols_with_invalid.astype(bool) +mols_with_invalid[mask] + +# %% [markdown] +# Having a failsafe SmilesToMol conversion leads us to the next step, featurization. The transformers in safe inference mode now return a NumPy masked array instead of a regular NumPy array. They simply evaluate the incoming mols in a boolean context, so e.g. `None`, `np.nan` and other Python objects that evaluate to False will also get masked (e.g. if you use a dataframe with an ROMol column produced with the PandasTools utility) + +# %% +from scikit_mol.fingerprints import MorganFingerprintTransformer + +mfp = MorganFingerprintTransformer(radius=2, nBits=25, safe_inference_mode=True) +fps = mfp.transform(mols_with_invalid) +fps + + +# %% [markdown] +# However, scikit-learn models currently accept masked arrays, but they do not respect the mask! So if you fed one directly to a model for training, it would seemingly work, but the invalid samples would all carry the fill_value, meaning you could get weird results. Instead we need the last piece of the puzzle, the SafeInferenceWrapper class. + +# %% +from scikit_mol.safeinference import SafeInferenceWrapper +from sklearn.linear_model import LogisticRegression +import numpy as np + +regressor = LogisticRegression() +wrapper = SafeInferenceWrapper(regressor, safe_inference_mode=True) +wrapper.fit(fps, [0,1,0,1,0,1]) +wrapper.predict(fps) + + +# %% [markdown] +# + +# %% [markdown] +# Both fit and predict went fine, and the result shows `nan` for the invalid entries. However, please note that fitting in safe_inference_mode is not recommended in a training session, but you are warned rather than blocked, because maybe you know what you're doing and do it on purpose.
+# The SafeInferenceWrapper not only handles rows that are masked in masked arrays, but also checks rows for nonfinite values and filters these away. Sometimes a descriptor may return an inf or nan, even though the molecule itself is valid. The masking of nonfinite values can be switched off, e.g. if you are using a model that can handle missing data and only want to filter away invalid molecules. +# +# ## Setting safe_inference_mode post-training +# As I said before, I believe in catching errors and fixing those during training, but what do we do when we need to switch on safe inference mode for all objects in a pipeline? There's of course a tool for that, so let's demo it: + +# %% +from scikit_mol.safeinference import set_safe_inference_mode +from sklearn.pipeline import Pipeline + +pipe = Pipeline([ + ("smi2mol", SmilesToMolTransformer()), + ("mfp", MorganFingerprintTransformer(radius=2, nBits=25)), + ("safe_regressor", SafeInferenceWrapper(LogisticRegression())) +]) + +pipe.fit(smiles, [1,0,1,0]) + +print("Without safe inference mode:") +try: + pipe.predict(smiles_with_invalid) +except Exception as e: + print("Prediction failed with exception: ", e) +print() + +set_safe_inference_mode(pipe, True) + +print("With safe inference mode:") +print(pipe.predict(smiles_with_invalid)) + +# %% [markdown] +# We see that the prediction fails without safe inference mode, and proceeds once it's conveniently enabled via the `set_safe_inference_mode` utility. The model is now ready to be saved and reused in a more failsafe manner :-) + +# %% [markdown] +# ## Combining safe_inference_mode with pandas output +# One potential issue can happen when we combine the safe_inference_mode with the Pandas output mode of the transformers. It will work, but depending on the batch something surprising can happen due to the way that Pandas converts masked NumPy arrays. Let me demonstrate the issue; first we transform a batch without any errors. + +# %% +mfp.set_output(transform="pandas") + +mols = smi2mol.transform(smiles) + +fps = mfp.transform(mols) +fps + +# %% [markdown] +# Then let's see what happens when we transform a batch with an invalid molecule: + +# %% +fps = mfp.transform(mols_with_invalid) +fps + +# %% [markdown] +# The second output no longer contains integers, but floats. As most sklearn models cast input arrays to float32 internally, this difference is likely benign, but that's not guaranteed! So if you want to use pandas output for your production models, do check that the final outputs are the same for the valid rows, with and without an invalid row in the batch. Alternatively, the dtype of the transformer output can be switched to float for consistency. + +# %% +mfp_float = MorganFingerprintTransformer(radius=2, nBits=25, safe_inference_mode=True, dtype=np.float32) +mfp_float.set_output(transform="pandas") +fps = mfp_float.transform(mols) +fps + +# %% [markdown] +# I hope this new feature of Scikit-Mol will make it even easier to use models, even in environments without SMILES or molecule validity guarantees.
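[Editorial note: as a companion to the notebook above, here is a minimal sketch of how a caller might post-process the output of a pipeline in safe inference mode, where predictions for invalid inputs come back as `nan`. It uses plain NumPy only; the batch, variable names, and the hard-coded prediction array are illustrative stand-ins, not part of this PR.]

```python
import numpy as np

# Toy stand-ins: an incoming batch and the predictions a safe-inference
# pipeline might return for it (nan marks the entries that failed).
smiles_batch = ["C1=CC=C(C=C1)F", "N(C)(C)(C)C", "I'm not a SMILES"]
predictions = np.array([1.0, np.nan, np.nan])  # e.g. pipe.predict(smiles_batch)

# Split the batch into successful predictions and failed inputs,
# so the failures can be logged or reported back to the caller.
ok = np.isfinite(predictions)
results = {smi: pred for smi, pred, keep in zip(smiles_batch, predictions, ok) if keep}
failures = [smi for smi, keep in zip(smiles_batch, ok) if not keep]

print("Predictions:", results)     # {'C1=CC=C(C=C1)F': 1.0}
print("Failed inputs:", failures)  # ['N(C)(C)(C)C', "I'm not a SMILES"]
```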
diff --git a/notebooks/README.md b/notebooks/README.md index 8b0ec12..b744709 100644 --- a/notebooks/README.md +++ b/notebooks/README.md @@ -14,3 +14,4 @@ This is a collection of notebooks in the notebooks directory which demonstrates - [Using skopt for hyperparameter tuning](https://github.com/EBjerrum/scikit-mol/tree/main/notebooks/08_external_library_skopt.ipynb) - [Testing different fingerprints as part of the hyperparameter optimization](https://github.com/EBjerrum/scikit-mol/blob/main/notebooks/09_Combinatorial_Method_Usage_with_FingerPrint_Transformers.ipynb) - [Using pandas output for easy feature importance analysis and combine pre-exisitng values with new computations](https://github.com/EBjerrum/scikit-mol/blob/main/notebooks/10_pipeline_pandas_output.ipynb) +- [Working with pipelines and estimators in safe inference mode](https://github.com/EBjerrum/scikit-mol/blob/main/notebooks/11_safe_inference.ipynb) diff --git a/scikit_mol/conversions.py b/scikit_mol/conversions.py index 1c75ba5..450ab31 100644 --- a/scikit_mol/conversions.py +++ b/scikit_mol/conversions.py @@ -2,18 +2,44 @@ import multiprocessing from typing import Union from rdkit import Chem +from rdkit.rdBase import BlockLogs import numpy as np from sklearn.base import BaseEstimator, TransformerMixin -from scikit_mol.core import check_transform_input, feature_names_default_mol ,DEFAULT_MOL_COLUMN_NAME +from scikit_mol.core import ( + check_transform_input, + feature_names_default_mol, + DEFAULT_MOL_COLUMN_NAME, + InvalidMol, +) +# from scikit_mol._invalid import InvalidMol -class SmilesToMolTransformer(BaseEstimator, TransformerMixin): - def __init__(self, parallel: Union[bool, int] = False): +class SmilesToMolTransformer(BaseEstimator, TransformerMixin): + """ + Transformer for converting SMILES strings to RDKit mol objects. + + This transformer can be included in pipelines during development and training, + but the safe inference mode should only be enabled when deploying models for + inference in production environments. + + Parameters: + ----------- + parallel : Union[bool, int], default=False + If True or int > 1, enables parallel processing. + safe_inference_mode : bool, default=False + If True, enables safeguards for handling invalid data during inference. + This should only be set to True when deploying models to production. 
+ """ + + def __init__( + self, parallel: Union[bool, int] = False, safe_inference_mode: bool = False + ): self.parallel = parallel - self.start_method = None #TODO implement handling of start_method + self.start_method = None # TODO implement handling of start_method + self.safe_inference_mode = safe_inference_mode @feature_names_default_mol def get_feature_names_out(self, input_features=None): @@ -39,39 +65,73 @@ def transform(self, X_smiles_list, y=None): Raises ------ ValueError - Raises ValueError if a SMILES string is unparsable by RDKit + Raises ValueError if a SMILES string is unparsable by RDKit and safe_inference_mode is False """ - if not self.parallel: return self._transform(X_smiles_list) elif self.parallel: - n_processes = self.parallel if self.parallel > 1 else None # Pool(processes=None) autodetects - n_chunks = n_processes*2 if n_processes is not None else multiprocessing.cpu_count()*2 #TODO, tune the number of chunks per child process + n_processes = ( + self.parallel if self.parallel > 1 else None + ) # Pool(processes=None) autodetects + n_chunks = ( + n_processes * 2 + if n_processes is not None + else multiprocessing.cpu_count() * 2 + ) # TODO, tune the number of chunks per child process with get_context(self.start_method).Pool(processes=n_processes) as pool: - x_chunks = np.array_split(X_smiles_list, n_chunks) - arrays = pool.map(self._transform, x_chunks) #is the helper function a safer way of handling the picklind and child process communication - arr = np.concatenate(arrays) - return arr + x_chunks = np.array_split(X_smiles_list, n_chunks) + arrays = pool.map( + self._transform, x_chunks + ) # is the helper function a safer way of handling the picklind and child process communication + arr = np.concatenate(arrays) + return arr @check_transform_input def _transform(self, X): X_out = [] - for smiles in X: - mol = Chem.MolFromSmiles(smiles) - if mol: - X_out.append(mol) - else: - raise ValueError(f'Issue with parsing SMILES {smiles}\nYou probably should use the scikit-mol.sanitizer.Sanitizer on your dataset first') - - return np.array(X_out).reshape(-1,1) + with BlockLogs(): + for smiles in X: + mol = Chem.MolFromSmiles(smiles, sanitize=False) + if mol: + errors = Chem.DetectChemistryProblems(mol) + if errors: + error_message = "\n".join(error.Message() for error in errors) + message = f"Invalid Molecule: {error_message}" + X_out.append(InvalidMol(str(self), message)) + else: + Chem.SanitizeMol(mol) + X_out.append(mol) + else: + message = f"Invalid SMILES: {smiles}" + X_out.append(InvalidMol(str(self), message)) + if not self.safe_inference_mode and not all(X_out): + fails = [x for x in X_out if not x] + raise ValueError( + f"Invalid input found: {fails}." + ) # TODO with this approach we get all errors, but we do process ALL the smiles first which could be slow + return np.array(X_out).reshape(-1, 1) @check_transform_input - def inverse_transform(self, X_mols_list, y=None): #TODO, maybe the inverse transform should be configurable e.g. isomericSmiles etc.? 
+ def inverse_transform(self, X_mols_list, y=None): X_out = [] for mol in X_mols_list: - smiles = Chem.MolToSmiles(mol) - X_out.append(smiles) + if isinstance(mol, Chem.Mol): + try: + smiles = Chem.MolToSmiles(mol) + X_out.append(smiles) + except Exception as e: + X_out.append( + InvalidMol( + str(self), f"Error converting Mol to SMILES: {str(e)}" + ) + ) + else: + X_out.append(InvalidMol(str(self), f"Not a Mol: {mol}")) + + if not self.safe_inference_mode and not all(isinstance(x, str) for x in X_out): + fails = [x for x in X_out if not isinstance(x, str)] + raise ValueError(f"Invalid Mols found: {fails}.") - return np.array(X_out).reshape(-1,1) + return np.array(X_out).reshape(-1, 1) diff --git a/scikit_mol/core.py b/scikit_mol/core.py index 66685a6..9b13680 100644 --- a/scikit_mol/core.py +++ b/scikit_mol/core.py @@ -5,6 +5,7 @@ Users who want to create their own transformers should use this module. """ +from dataclasses import dataclass import functools import numpy as np @@ -16,48 +17,71 @@ DEFAULT_MOL_COLUMN_NAME = "ROMol" +@dataclass +class InvalidMol: + """ + Represents molecules which raised an error during a pipeline step. + Evaluates to False in boolean contexts. + """ + + pipeline_step: str + error: str + + def __bool__(self): + return False + + def __repr__(self): + return f"InvalidMol('{self.pipeline_step}', error='{self.error}')" + + def _validate_transform_input(X): - """Validate and adapt the input of the _transform method""" - try: - shape = X.shape - except AttributeError: - # If X is not array-like or dataframe-like, - # we just return it as is, so users can use simple lists and sequences. - return X - # If X is an array-like or dataframe-like, we make sure it is compatible with - # the scikit-learn API, and that it contains a single column: - # scikit-mol transformers need a single column with smiles or mols. - if len(shape) == 1: - return X # Flatt Arrays and list-like data are also supported #TODO, add a warning about non-2D data if logging is implemented - if shape[1] != 1: - raise ValueError("Only one column supported. You may want to use a ColumnTransformer https://scikit-learn.org/stable/modules/generated/sklearn.compose.ColumnTransformer.html ") - return np.array(X).flatten() + """Validate and adapt the input of the _transform method""" + try: + shape = X.shape + except AttributeError: + # If X is not array-like or dataframe-like, + # we just return it as is, so users can use simple lists and sequences. + return X + # If X is an array-like or dataframe-like, we make sure it is compatible with + # the scikit-learn API, and that it contains a single column: + # scikit-mol transformers need a single column with smiles or mols. + if len(shape) == 1: + return X # Flat arrays and list-like data are also supported #TODO, add a warning about non-2D data if logging is implemented + if shape[1] != 1: + raise ValueError( + "Only one column supported. You may want to use a ColumnTransformer https://scikit-learn.org/stable/modules/generated/sklearn.compose.ColumnTransformer.html " + ) + return np.array(X).flatten() + def check_transform_input(method): """ Decorator to check the input of the _transform method and make it compatible with the scikit-learn API and with downstream methods. """ + @functools.wraps(method) def wrapper(obj, X): X = _validate_transform_input(X) - result = method(obj, X) + result = method(obj, X) # If the output of the _transform method # must be changed depending on the initial type of X, do it here.
return result return wrapper + def feature_names_default_mol(method): """ Decorator that returns the default feature names for the mol object """ + @functools.wraps(method) def wrapper(obj, input_features=None): prefix = DEFAULT_MOL_COLUMN_NAME if input_features is not None: - return np.array([f'{prefix}_{name}' for name in input_features]) + return np.array([f"{prefix}_{name}" for name in input_features]) else: return np.array([prefix]) - return wrapper \ No newline at end of file + return wrapper diff --git a/scikit_mol/descriptors.py b/scikit_mol/descriptors.py index a516fc5..905a098 100644 --- a/scikit_mol/descriptors.py +++ b/scikit_mol/descriptors.py @@ -12,10 +12,9 @@ from scikit_mol.core import check_transform_input - class MolecularDescriptorTransformer(BaseEstimator, TransformerMixin): """Descriptor calculation transformer - + Parameters ---------- desc_list : (List of descriptor names) @@ -23,9 +22,12 @@ class MolecularDescriptorTransformer(BaseEstimator, TransformerMixin): parallel : boolean, int if True, multiprocessing will be used. If set to an int > 1, that specified number of processes will be used, otherwise it's autodetected. - start_method : str + start_method : str The method to start child processes when parallel=True. can be 'fork', 'spawn' or 'forkserver'. If None, the OS and Pythons default will be used. + safe_inference_mode : bool + If True, enables safeguards for handling invalid data during inference. + This should only be set to True when deploying models to production. Returns ------- @@ -34,14 +36,20 @@ + def __init__( - self, desc_list: Optional[str] = None, + self, + desc_list: Optional[str] = None, parallel: Union[bool, int] = False, - start_method: str = None#"fork" - ): + start_method: str = None, # "fork", + safe_inference_mode: bool = False, + dtype: np.dtype = np.float32, + ): self.desc_list = desc_list self.parallel = parallel self.start_method = start_method + self.safe_inference_mode = safe_inference_mode + self.dtype = dtype def _get_desc_calculator(self) -> MolecularDescriptorCalculator: if self.desc_list: @@ -50,9 +58,7 @@ for desc_name in self.desc_list if desc_name not in self.available_descriptors ] - assert ( - not unknown_descriptors - ), f"Unknown descriptor names {unknown_descriptors} specified, please check available_descriptors property\nPlease check availble list {self.available_descriptors}" + assert not unknown_descriptors, f"Unknown descriptor names {unknown_descriptors} specified, please check available_descriptors property\nPlease check available list {self.available_descriptors}" else: self.desc_list = self.available_descriptors return MolecularDescriptorCalculator(self.desc_list) @@ -89,24 +95,41 @@ def start_method(self, start_method): """Allowed methods are spawn, fork and forkserver on MacOS and Linux, only spawn is possible on Windows.
None will choose the default for the OS and version of Python.""" allowed_start_methods = ["spawn", "fork", "forkserver", None] - assert start_method in allowed_start_methods, f"start_method not in allowed methods {allowed_start_methods}" + assert ( + start_method in allowed_start_methods + ), f"start_method not in allowed methods {allowed_start_methods}" self._start_method = start_method - def _transform_mol(self, mol: Mol) -> List[Any]: - return list(self.calculators.CalcDescriptors(mol)) + def _transform_mol(self, mol: Mol) -> Union[np.ndarray, np.ma.MaskedArray]: + if not mol: + if self.safe_inference_mode: + return np.ma.masked_all(len(self.desc_list)) + else: + raise ValueError(f"Invalid molecule provided: {mol}") + try: + return np.array(list(self.calculators.CalcDescriptors(mol))) + except Exception as e: + if self.safe_inference_mode: + return np.ma.masked_all(len(self.desc_list)) + else: + raise e def fit(self, x, y=None): """Included for scikit-learn compatibility, does nothing""" return self @check_transform_input - def _transform(self, x: List[Mol]) -> np.ndarray: - arr = np.zeros((len(x), len(self.desc_list))) - for i, mol in enumerate(x): - arr[i, :] = self._transform_mol(mol) - return arr + def _transform(self, x: List[Mol]) -> Union[np.ndarray, np.ma.MaskedArray]: + if self.safe_inference_mode: + arrays = [self._transform_mol(mol) for mol in x] + return np.ma.array(arrays, dtype=self.dtype) + else: + arr = np.zeros((len(x), len(self.desc_list)), dtype=self.dtype) + for i, mol in enumerate(x): + arr[i, :] = self._transform_mol(mol) + return arr - def transform(self, x: List[Mol], y=None) -> np.ndarray: + def transform(self, x: List[Mol], y=None) -> Union[np.ndarray, np.ma.MaskedArray]: """Transform a list of molecules into an array of descriptor values Parameters ---------- @@ -117,34 +140,39 @@ def transform(self, x: List[Mol], y=None) -> np.ndarray: Returns ------- - np.array - Descriptors, shape (samples, length of .selected_descriptors ) - + Union[np.ndarray, np.ma.MaskedArray] + Descriptors, shape (samples, length of .selected_descriptors) + """ if not self.parallel: return self._transform(x) elif self.parallel: - n_processes = self.parallel if self.parallel > 1 else None # Pool(processes=None) autodetects - n_chunks = n_processes if n_processes is not None else multiprocessing.cpu_count() #TODO, tune the number of chunks per child process - + n_processes = ( + self.parallel if self.parallel > 1 else None + ) # Pool(processes=None) autodetects + n_chunks = ( + n_processes if n_processes is not None else multiprocessing.cpu_count() + ) # TODO, tune the number of chunks per child process + with get_context(self.start_method).Pool(processes=n_processes) as pool: params = self.get_params() - x_chunks = np.array_split(x, n_chunks) - #x_chunks = [x.reshape(-1, 1) for x in x_chunks] - arrays = pool.map(parallel_helper, [(params, x) for x in x_chunks]) #is the helper function a safer way of handling the picklind and child process communication - arr = np.concatenate(arrays) + x_chunks = np.array_split(x, n_chunks) + arrays = pool.map(parallel_helper, [(params, x) for x in x_chunks]) + if self.safe_inference_mode: + arr = np.ma.concatenate(arrays) + else: + arr = np.concatenate(arrays) return arr # May be safer to instantiate the transformer object in the child process, and only transfer the parameters # There were issues with freezing when using RDKit 2022.3 def parallel_helper(args): - """Will get a tuple with Desc2DTransformer parameters and mols to transform. 
+ """Will get a tuple with Desc2DTransformer parameters and mols to transform. Will then instantiate the transformer and transform the molecules""" from scikit_mol.descriptors import MolecularDescriptorTransformer - + params, mols = args transformer = MolecularDescriptorTransformer(**params) y = transformer._transform(mols) return y - \ No newline at end of file diff --git a/scikit_mol/fingerprints.py b/scikit_mol/fingerprints.py index 767bfc6..f044a06 100644 --- a/scikit_mol/fingerprints.py +++ b/scikit_mol/fingerprints.py @@ -1,11 +1,11 @@ -#%% from multiprocessing import Pool, get_context import multiprocessing import re from typing import Union from rdkit import Chem from rdkit import DataStructs -#from rdkit.Chem.AllChem import GetMorganFingerprintAsBitVect + +# from rdkit.Chem.AllChem import GetMorganFingerprintAsBitVect from rdkit.Chem import rdMolDescriptors from rdkit.Chem import rdFingerprintGenerator from rdkit.Chem import rdMHFPFingerprint @@ -21,18 +21,24 @@ from abc import ABC, abstractmethod -_PATTERN_FINGERPRINT_TRANSFORMER = re.compile(r"^(?P\w+)FingerprintTransformer$") -#%% -class FpsTransformer(ABC, BaseEstimator, TransformerMixin): +_PATTERN_FINGERPRINT_TRANSFORMER = re.compile( + r"^(?P\w+)FingerprintTransformer$" +) - def __init__(self, parallel: Union[bool, int] = False, start_method: str = None): - self.parallel = parallel - self.start_method = start_method #TODO implement handling of start_method - # The dtype of the fingerprint array computed by the transformer - # If needed this property can be overwritten in the child class. - _DTYPE_FINGERPRINT = np.int8 +class FpsTransformer(ABC, BaseEstimator, TransformerMixin): + def __init__( + self, + parallel: Union[bool, int] = False, + start_method: str = None, + safe_inference_mode: bool = False, + dtype: np.dtype = np.int8, + ): + self.parallel = parallel + self.start_method = start_method + self.safe_inference_mode = safe_inference_mode + self.dtype = dtype def _get_column_prefix(self) -> str: matched = _PATTERN_FINGERPRINT_TRANSFORMER.match(type(self).__name__) @@ -54,7 +60,9 @@ def get_display_feature_names_out(self, input_features=None): """ prefix = self._get_column_prefix() n_digits = self._get_n_digits_column_suffix() - return np.array([f"{prefix}_{str(i).zfill(n_digits)}" for i in range(1, self.nBits + 1)]) + return np.array( + [f"{prefix}_{str(i).zfill(n_digits)}" for i in range(1, self.nBits + 1)] + ) def get_feature_names_out(self, input_features=None): """Get feature names for fingerprint transformers @@ -67,21 +75,31 @@ def get_feature_names_out(self, input_features=None): @abstractmethod def _mol2fp(self, mol): - """Generate descriptor from mol + """Generate fingerprint from mol MUST BE OVERWRITTEN """ raise NotImplementedError("_mol2fp not implemented") def _fp2array(self, fp): - arr = np.zeros((self.nBits,), dtype=self._DTYPE_FINGERPRINT) - DataStructs.ConvertToNumpyArray(fp, arr) - return arr + if fp: + arr = np.zeros((self.nBits,), dtype=self.dtype) + DataStructs.ConvertToNumpyArray(fp, arr) + return arr + else: + return np.ma.masked_all((self.nBits,), dtype=self.dtype) def _transform_mol(self, mol): - fp = self._mol2fp(mol) - arr = self._fp2array(fp) - return arr + if not mol and self.safe_inference_mode: + return self._fp2array(False) + try: + fp = self._mol2fp(mol) + return self._fp2array(fp) + except Exception as e: + if self.safe_inference_mode: + return self._fp2array(False) + else: + raise e def fit(self, X, y=None): """Included for scikit-learn compatibility @@ -92,16 +110,22 @@ def 
fit(self, X, y=None): @check_transform_input def _transform(self, X): - arr = np.zeros((len(X), self.nBits), dtype=self._DTYPE_FINGERPRINT) - for i, mol in enumerate(X): - arr[i,:] = self._transform_mol(mol) - return arr + if self.safe_inference_mode: + # Use the new method with masked arrays if we're in safe inference mode + arrays = [self._transform_mol(mol) for mol in X] + return np.ma.stack(arrays) + else: + # Use the original, faster method if we're not in safe inference mode + arr = np.zeros((len(X), self.nBits), dtype=self.dtype) + for i, mol in enumerate(X): + arr[i, :] = self._transform_mol(mol) + return arr def _transform_sparse(self, X): - arr = np.zeros((len(X), self.nBits), dtype=self._DTYPE_FINGERPRINT) + arr = np.zeros((len(X), self.nBits), dtype=self.dtype) for i, mol in enumerate(X): - arr[i,:] = self._transform_mol(mol) - + arr[i, :] = self._transform_mol(mol) + return lil_matrix(arr) def transform(self, X, y=None): @@ -123,29 +147,48 @@ def transform(self, X, y=None): return self._transform(X) elif self.parallel: - n_processes = self.parallel if self.parallel > 1 else None # Pool(processes=None) autodetects - n_chunks = n_processes if n_processes is not None else multiprocessing.cpu_count() - + n_processes = ( + self.parallel if self.parallel > 1 else None + ) # Pool(processes=None) autodetects + n_chunks = ( + n_processes if n_processes is not None else multiprocessing.cpu_count() + ) + with get_context(self.start_method).Pool(processes=n_processes) as pool: x_chunks = np.array_split(X, n_chunks) - #TODO check what is fastest, pickle or recreate and do this only for classes that need this - #arrays = pool.map(self._transform, x_chunks) + # TODO check what is fastest, pickle or recreate and do this only for classes that need this + # arrays = pool.map(self._transform, x_chunks) parameters = self.get_params() # TODO: create "transform_parallel" function in the core module, # and use it here and in the descriptors transformer - #x_chunks = [np.array(x).reshape(-1, 1) for x in x_chunks] - arrays = pool.map(parallel_helper, [(self.__class__.__name__, parameters, x_chunk) for x_chunk in x_chunks]) - - arr = np.concatenate(arrays) + # x_chunks = [np.array(x).reshape(-1, 1) for x in x_chunks] + arrays = pool.map( + parallel_helper, + [ + (self.__class__.__name__, parameters, x_chunk) + for x_chunk in x_chunks + ], + ) + if self.safe_inference_mode: + arr = np.ma.concatenate(arrays) + else: + arr = np.concatenate(arrays) return arr class MACCSKeysFingerprintTransformer(FpsTransformer): - def __init__(self, parallel: Union[bool, int] = False): + def __init__( + self, + parallel: Union[bool, int] = False, + safe_inference_mode: bool = False, + dtype: np.dtype = np.int8, + ): """MACCS keys fingerprinter calculates the 167 fixed MACCS keys """ - super().__init__(parallel = parallel) + super().__init__( + parallel=parallel, safe_inference_mode=safe_inference_mode, dtype=dtype + ) self.nBits = 167 @property @@ -155,20 +198,32 @@ def nBits(self): @nBits.setter def nBits(self, nBits): if nBits != 167: - raise ValueError("nBits can only be 167, matching the number of defined MACCS keys!") + raise ValueError( + "nBits can only be 167, matching the number of defined MACCS keys!" 
+ ) self._nBits = nBits def _mol2fp(self, mol): - return rdMolDescriptors.GetMACCSKeysFingerprint( - mol - ) + return rdMolDescriptors.GetMACCSKeysFingerprint(mol) + class RDKitFingerprintTransformer(FpsTransformer): - def __init__(self, minPath:int = 1, maxPath:int =7, useHs:bool = True, branchedPaths:bool = True, - useBondOrder:bool = True, countSimulation:bool = False, countBounds = None, - fpSize:int = 2048, numBitsPerFeature:int = 2, atomInvariantsGenerator = None, - parallel: Union[bool, int] = False - ): + def __init__( + self, + minPath: int = 1, + maxPath: int = 7, + useHs: bool = True, + branchedPaths: bool = True, + useBondOrder: bool = True, + countSimulation: bool = False, + countBounds=None, + fpSize: int = 2048, + numBitsPerFeature: int = 2, + atomInvariantsGenerator=None, + parallel: Union[bool, int] = False, + safe_inference_mode: bool = False, + dtype: np.dtype = np.int8, + ): """Calculates the RDKit fingerprints Parameters @@ -194,7 +249,9 @@ def __init__(self, minPath:int = 1, maxPath:int =7, useHs:bool = True, branchedP atomInvariantsGenerator : _type_, optional atom invariants to be used during fingerprint generation, by default None """ - super().__init__(parallel = parallel) + super().__init__( + parallel=parallel, safe_inference_mode=safe_inference_mode, dtype=dtype + ) self.minPath = minPath self.maxPath = maxPath self.useHs = useHs @@ -210,27 +267,48 @@ def __init__(self, minPath:int = 1, maxPath:int =7, useHs:bool = True, branchedP def fpSize(self): return self.nBits - #Scikit-Learn expects to be able to set fpSize directly on object via .set_params(), so this updates nBits used by the abstract class + # Scikit-Learn expects to be able to set fpSize directly on object via .set_params(), so this updates nBits used by the abstract class @fpSize.setter def fpSize(self, fpSize): self.nBits = fpSize def _mol2fp(self, mol): - generator = rdFingerprintGenerator.GetRDKitFPGenerator(minPath=int(self.minPath), maxPath=int(self.maxPath), - useHs=bool(self.useHs), branchedPaths=bool(self.branchedPaths), - useBondOrder=bool(self.useBondOrder), - countSimulation=bool(self.countSimulation), - countBounds=bool(self.countBounds), fpSize=int(self.fpSize), - numBitsPerFeature=int(self.numBitsPerFeature), - atomInvariantsGenerator=self.atomInvariantsGenerator - ) + generator = rdFingerprintGenerator.GetRDKitFPGenerator( + minPath=int(self.minPath), + maxPath=int(self.maxPath), + useHs=bool(self.useHs), + branchedPaths=bool(self.branchedPaths), + useBondOrder=bool(self.useBondOrder), + countSimulation=bool(self.countSimulation), + countBounds=bool(self.countBounds), + fpSize=int(self.fpSize), + numBitsPerFeature=int(self.numBitsPerFeature), + atomInvariantsGenerator=self.atomInvariantsGenerator, + ) return generator.GetFingerprint(mol) -class AtomPairFingerprintTransformer(FpsTransformer): #FIXME, some of the init arguments seems to be molecule specific, and should probably not be setable? 
- def __init__(self, minLength:int = 1, maxLength:int = 30, fromAtoms = 0, ignoreAtoms = 0, atomInvariants = 0, - nBitsPerEntry:int = 4, includeChirality:bool = False, use2D:bool = True, confId:int = -1, nBits=2048, - useCounts:bool=False, parallel: Union[bool, int] = False,): - super().__init__(parallel = parallel) + +class AtomPairFingerprintTransformer(FpsTransformer): + def __init__( + self, + minLength: int = 1, + maxLength: int = 30, + fromAtoms=0, + ignoreAtoms=0, + atomInvariants=0, + nBitsPerEntry: int = 4, + includeChirality: bool = False, + use2D: bool = True, + confId: int = -1, + nBits=2048, + useCounts: bool = False, + parallel: Union[bool, int] = False, + safe_inference_mode: bool = False, + dtype: np.dtype = np.int8, + ): + super().__init__( + parallel=parallel, safe_inference_mode=safe_inference_mode, dtype=dtype + ) self.minLength = minLength self.maxLength = maxLength self.fromAtoms = fromAtoms @@ -245,34 +323,52 @@ def __init__(self, minLength:int = 1, maxLength:int = 30, fromAtoms = 0, ignoreA def _mol2fp(self, mol): if self.useCounts: - return rdMolDescriptors.GetHashedAtomPairFingerprint(mol, nBits=int(self.nBits), - minLength=int(self.minLength), - maxLength=int(self.maxLength), - fromAtoms=self.fromAtoms, - ignoreAtoms=self.ignoreAtoms, - atomInvariants=self.atomInvariants, - includeChirality=bool(self.includeChirality), - use2D=bool(self.use2D), - confId=int(self.confId) - ) + return rdMolDescriptors.GetHashedAtomPairFingerprint( + mol, + nBits=int(self.nBits), + minLength=int(self.minLength), + maxLength=int(self.maxLength), + fromAtoms=self.fromAtoms, + ignoreAtoms=self.ignoreAtoms, + atomInvariants=self.atomInvariants, + includeChirality=bool(self.includeChirality), + use2D=bool(self.use2D), + confId=int(self.confId), + ) else: - return rdMolDescriptors.GetHashedAtomPairFingerprintAsBitVect(mol, nBits=int(self.nBits), - minLength=int(self.minLength), - maxLength=int(self.maxLength), - fromAtoms=self.fromAtoms, - ignoreAtoms=self.ignoreAtoms, - atomInvariants=self.atomInvariants, - nBitsPerEntry=int(self.nBitsPerEntry), - includeChirality=bool(self.includeChirality), - use2D=bool(self.use2D), - confId=int(self.confId) - ) + return rdMolDescriptors.GetHashedAtomPairFingerprintAsBitVect( + mol, + nBits=int(self.nBits), + minLength=int(self.minLength), + maxLength=int(self.maxLength), + fromAtoms=self.fromAtoms, + ignoreAtoms=self.ignoreAtoms, + atomInvariants=self.atomInvariants, + nBitsPerEntry=int(self.nBitsPerEntry), + includeChirality=bool(self.includeChirality), + use2D=bool(self.use2D), + confId=int(self.confId), + ) + class TopologicalTorsionFingerprintTransformer(FpsTransformer): - def __init__(self, targetSize:int = 4, fromAtoms = 0, ignoreAtoms = 0, atomInvariants = 0, - includeChirality:bool = False, nBitsPerEntry:int = 4, nBits=2048, - useCounts:bool=False, parallel: Union[bool, int] = False): - super().__init__(parallel = parallel) + def __init__( + self, + targetSize: int = 4, + fromAtoms=0, + ignoreAtoms=0, + atomInvariants=0, + includeChirality: bool = False, + nBitsPerEntry: int = 4, + nBits=2048, + useCounts: bool = False, + parallel: Union[bool, int] = False, + safe_inference_mode: bool = False, + dtype: np.dtype = np.int8, + ): + super().__init__( + parallel=parallel, safe_inference_mode=safe_inference_mode, dtype=dtype + ) self.targetSize = targetSize self.fromAtoms = fromAtoms self.ignoreAtoms = ignoreAtoms @@ -284,46 +380,65 @@ def __init__(self, targetSize:int = 4, fromAtoms = 0, ignoreAtoms = 0, atomInvar def _mol2fp(self, mol): if 
self.useCounts: - return rdMolDescriptors.GetHashedTopologicalTorsionFingerprint(mol, nBits=int(self.nBits), - targetSize=int(self.targetSize), - fromAtoms=self.fromAtoms, - ignoreAtoms=self.ignoreAtoms, - atomInvariants=self.atomInvariants, - includeChirality=bool(self.includeChirality), - ) + return rdMolDescriptors.GetHashedTopologicalTorsionFingerprint( + mol, + nBits=int(self.nBits), + targetSize=int(self.targetSize), + fromAtoms=self.fromAtoms, + ignoreAtoms=self.ignoreAtoms, + atomInvariants=self.atomInvariants, + includeChirality=bool(self.includeChirality), + ) else: - return rdMolDescriptors.GetHashedTopologicalTorsionFingerprintAsBitVect(mol, nBits=int(self.nBits), - targetSize=int(self.targetSize), - fromAtoms=self.fromAtoms, - ignoreAtoms=self.ignoreAtoms, - atomInvariants=self.atomInvariants, - includeChirality=bool(self.includeChirality), - nBitsPerEntry=int(self.nBitsPerEntry) - ) + return rdMolDescriptors.GetHashedTopologicalTorsionFingerprintAsBitVect( + mol, + nBits=int(self.nBits), + targetSize=int(self.targetSize), + fromAtoms=self.fromAtoms, + ignoreAtoms=self.ignoreAtoms, + atomInvariants=self.atomInvariants, + includeChirality=bool(self.includeChirality), + nBitsPerEntry=int(self.nBitsPerEntry), + ) + class MHFingerprintTransformer(FpsTransformer): - # https://jcheminf.biomedcentral.com/articles/10.1186/s13321-018-0321-8 - def __init__(self, radius:int=3, rings:bool=True, isomeric:bool=False, kekulize:bool=False, - min_radius:int=1, n_permutations:int=2048, seed:int=42, parallel: Union[bool, int] = False,): + def __init__( + self, + radius: int = 3, + rings: bool = True, + isomeric: bool = False, + kekulize: bool = False, + min_radius: int = 1, + n_permutations: int = 2048, + seed: int = 42, + parallel: Union[bool, int] = False, + safe_inference_mode: bool = False, + dtype: np.dtype = np.int32, + ): """Transforms the RDKit mol into the MinHash fingerprint (MHFP) + + https://jcheminf.biomedcentral.com/articles/10.1186/s13321-018-0321-8 + + Args: radius (int, optional): The MHFP radius. Defaults to 3. rings (bool, optional): Whether or not to include rings in the shingling. Defaults to True. isomeric (bool, optional): Whether the isomeric SMILES should be considered. Defaults to False. kekulize (bool, optional): Whether or not to kekulize the extracted SMILES. Defaults to False. min_radius (int, optional): The minimum radius that is used to extract n-gram. Defaults to 1. - n_permutations (int, optional): The number of permutations used for hashing. Defaults to 0, + n_permutations (int, optional): The number of permutations used for hashing. Defaults to 2048, this is effectively the length of the FP seed (int, optional): The value used to seed numpy.random. Defaults to 42.
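+ safe_inference_mode (bool, optional): If True, invalid molecules are returned as masked rows instead of raising an exception. Defaults to False. + dtype (np.dtype, optional): Data type of the returned fingerprint array. Defaults to np.int32.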
""" - super().__init__(parallel = parallel) + super().__init__( + parallel=parallel, safe_inference_mode=safe_inference_mode, dtype=dtype + ) self.radius = radius self.rings = rings self.isomeric = isomeric self.kekulize = kekulize self.min_radius = min_radius - #Set the .n_permutations and .seed without creating the encoder twice + # Set the .n_permutations and .seed without creating the encoder twice self._n_permutations = n_permutations self._seed = seed # create the encoder instance @@ -333,7 +448,7 @@ def __getstate__(self): # Get the state of the parent class state = super().__getstate__() # Remove the unpicklable property from the state - state.pop("mhfp_encoder", None) # mhfp_encoder is not picklable + state.pop("mhfp_encoder", None) # mhfp_encoder is not picklable return state def __setstate__(self, state): @@ -342,17 +457,19 @@ def __setstate__(self, state): # Re-create the unpicklable property self._recreate_encoder() - _DTYPE_FINGERPRINT = np.int32 - def _mol2fp(self, mol): - fp = self.mhfp_encoder.EncodeMol(mol, self.radius, self.rings, self.isomeric, self.kekulize, self.min_radius) + fp = self.mhfp_encoder.EncodeMol( + mol, self.radius, self.rings, self.isomeric, self.kekulize, self.min_radius + ) return fp - + def _fp2array(self, fp): return np.array(fp) def _recreate_encoder(self): - self.mhfp_encoder = rdMHFPFingerprint.MHFPEncoder(self._n_permutations, self._seed) + self.mhfp_encoder = rdMHFPFingerprint.MHFPEncoder( + self._n_permutations, self._seed + ) @property def seed(self): @@ -379,10 +496,23 @@ def nBits(self): # to be compliant with the requirement of the base class return self._n_permutations + class SECFingerprintTransformer(FpsTransformer): # https://jcheminf.biomedcentral.com/articles/10.1186/s13321-018-0321-8 - def __init__(self, radius:int=3, rings:bool=True, isomeric:bool=False, kekulize:bool=False, - min_radius:int=1, length:int=2048, n_permutations:int=0, seed:int=0, parallel: Union[bool, int] = False,): + def __init__( + self, + radius: int = 3, + rings: bool = True, + isomeric: bool = False, + kekulize: bool = False, + min_radius: int = 1, + length: int = 2048, + n_permutations: int = 0, + seed: int = 0, + parallel: Union[bool, int] = False, + safe_inference_mode: bool = False, + dtype: np.dtype = np.int8, + ): """Transforms the RDKit mol into the SMILES extended connectivity fingerprint (SECFP) Args: @@ -395,14 +525,16 @@ def __init__(self, radius:int=3, rings:bool=True, isomeric:bool=False, kekulize: n_permutations (int, optional): The number of permutations used for hashing. Defaults to 0. seed (int, optional): The value used to seed numpy.random. Defaults to 0. 
""" - super().__init__(parallel = parallel) + super().__init__( + parallel=parallel, safe_inference_mode=safe_inference_mode, dtype=dtype + ) self.radius = radius self.rings = rings self.isomeric = isomeric self.kekulize = kekulize self.min_radius = min_radius self.length = length - #Set the .n_permutations and seed without creating the encoder twice + # Set the .n_permutations and seed without creating the encoder twice self._n_permutations = n_permutations self._seed = seed # create the encoder instance @@ -412,7 +544,7 @@ def __getstate__(self): # Get the state of the parent class state = super().__getstate__() # Remove the unpicklable property from the state - state.pop("mhfp_encoder", None) # mhfp_encoder is not picklable + state.pop("mhfp_encoder", None) # mhfp_encoder is not picklable return state def __setstate__(self, state): @@ -422,10 +554,20 @@ def __setstate__(self, state): self._recreate_encoder() def _mol2fp(self, mol): - return self.mhfp_encoder.EncodeSECFPMol(mol, self.radius, self.rings, self.isomeric, self.kekulize, self.min_radius, self.length) + return self.mhfp_encoder.EncodeSECFPMol( + mol, + self.radius, + self.rings, + self.isomeric, + self.kekulize, + self.min_radius, + self.length, + ) def _recreate_encoder(self): - self.mhfp_encoder = rdMHFPFingerprint.MHFPEncoder(self._n_permutations, self._seed) + self.mhfp_encoder = rdMHFPFingerprint.MHFPEncoder( + self._n_permutations, self._seed + ) @property def seed(self): @@ -452,8 +594,20 @@ def nBits(self): # to be compliant with the requirement of the base class return self.length + class MorganFingerprintTransformer(FpsTransformer): - def __init__(self, nBits=2048, radius=2, useChirality=False, useBondTypes=True, useFeatures=False, useCounts=False, parallel: Union[bool, int] = False,): + def __init__( + self, + nBits=2048, + radius=2, + useChirality=False, + useBondTypes=True, + useFeatures=False, + useCounts=False, + parallel: Union[bool, int] = False, + safe_inference_mode: bool = False, + dtype: np.dtype = np.int8, + ): """Transform RDKit mols into Count or bit-based hashed MorganFingerprints Parameters @@ -471,7 +625,9 @@ def __init__(self, nBits=2048, radius=2, useChirality=False, useBondTypes=True, useCounts : bool, optional If toggled will create the count and not bit-based fingerprint, by default False """ - super().__init__(parallel = parallel) + super().__init__( + parallel=parallel, safe_inference_mode=safe_inference_mode, dtype=dtype + ) self.nBits = nBits self.radius = radius self.useChirality = useChirality @@ -482,19 +638,38 @@ def __init__(self, nBits=2048, radius=2, useChirality=False, useBondTypes=True, def _mol2fp(self, mol): if self.useCounts: return rdMolDescriptors.GetHashedMorganFingerprint( - mol,int(self.radius),nBits=int(self.nBits), useFeatures=bool(self.useFeatures), - useChirality=bool(self.useChirality), useBondTypes=bool(self.useBondTypes) + mol, + int(self.radius), + nBits=int(self.nBits), + useFeatures=bool(self.useFeatures), + useChirality=bool(self.useChirality), + useBondTypes=bool(self.useBondTypes), ) else: return rdMolDescriptors.GetMorganFingerprintAsBitVect( - mol,int(self.radius),nBits=int(self.nBits), useFeatures=bool(self.useFeatures), - useChirality=bool(self.useChirality), useBondTypes=bool(self.useBondTypes) + mol, + int(self.radius), + nBits=int(self.nBits), + useFeatures=bool(self.useFeatures), + useChirality=bool(self.useChirality), + useBondTypes=bool(self.useBondTypes), ) - + + class AvalonFingerprintTransformer(FpsTransformer): # Fingerprint from the Avalon 
toolkit, https://doi.org/10.1021/ci050413p - def __init__(self, nBits:int = 512, isQuery:bool = False, resetVect:bool = False, bitFlags:int = 15761407, useCounts:bool = False, parallel: Union[bool, int] = False,): - """ Transform RDKit mols into Count or bit-based Avalon Fingerprints + def __init__( + self, + nBits: int = 512, + isQuery: bool = False, + resetVect: bool = False, + bitFlags: int = 15761407, + useCounts: bool = False, + parallel: Union[bool, int] = False, + safe_inference_mode: bool = False, + dtype: np.dtype = np.int8, + ): + """Transform RDKit mols into Count or bit-based Avalon Fingerprints Parameters ---------- @@ -509,35 +684,39 @@ def __init__(self, nBits:int = 512, isQuery:bool = False, resetVect:bool = False useCounts : bool, optional If toggled will create the count and not bit-based fingerprint, by default False """ - super().__init__(parallel = parallel) + super().__init__( + parallel=parallel, safe_inference_mode=safe_inference_mode, dtype=dtype + ) self.nBits = nBits self.isQuery = isQuery self.resetVect = resetVect self.bitFlags = bitFlags self.useCounts = useCounts - + def _mol2fp(self, mol): if self.useCounts: - return pyAvalonTools.GetAvalonCountFP(mol, - nBits=int(self.nBits), - isQuery=bool(self.isQuery), - bitFlags=int(self.bitFlags) + return pyAvalonTools.GetAvalonCountFP( + mol, + nBits=int(self.nBits), + isQuery=bool(self.isQuery), + bitFlags=int(self.bitFlags), ) else: - return pyAvalonTools.GetAvalonFP(mol, - nBits=int(self.nBits), - isQuery=bool(self.isQuery), - resetVect=bool(self.resetVect), - bitFlags=int(self.bitFlags) + return pyAvalonTools.GetAvalonFP( + mol, + nBits=int(self.nBits), + isQuery=bool(self.isQuery), + resetVect=bool(self.resetVect), + bitFlags=int(self.bitFlags), ) def parallel_helper(args): """Parallel_helper takes a tuple with classname, the objects parameters and the mols to process. Then instantiates the class with the parameters and processes the mol. - Intention is to be able to do this in chilcprocesses as some classes can't be pickled""" + Intention is to be able to do this in child processes as some classes can't be pickled""" classname, parameters, X_mols = args from scikit_mol import fingerprints + transformer = getattr(fingerprints, classname)(**parameters) return transformer._transform(X_mols) - diff --git a/scikit_mol/safeinference.py b/scikit_mol/safeinference.py new file mode 100644 index 0000000..401af4a --- /dev/null +++ b/scikit_mol/safeinference.py @@ -0,0 +1,172 @@ +"""Wrapper for sklearn estimators and pipelines to handle errors.""" + +from typing import Any + +import numpy as np +import pandas as pd +from functools import wraps +import warnings +from sklearn.base import BaseEstimator, TransformerMixin +from sklearn.utils import check_array +from sklearn.utils.metaestimators import available_if + +from .utilities import set_safe_inference_mode + + +class MaskedArrayError(ValueError): + """Raised when a masked array is passed but safe_inference_mode is False.""" + + pass + + +def filter_invalid_rows(fill_value=np.nan, warn_on_invalid=False): + def decorator(func): + @wraps(func) + def wrapper(obj, X, y=None, *args, **kwargs): + if not getattr(obj, "safe_inference_mode", True): + if isinstance(X, np.ma.MaskedArray) and X.mask.any(): + raise MaskedArrayError( + f"Masked array detected with safe_inference_mode=False and {X.mask.any(axis=1).sum()} filtered rows. " + "Set safe_inference_mode=True to process masked arrays for inference of production models."
+ ) + return func(obj, X, y, *args, **kwargs) + + # Initialize valid_mask as all True + valid_mask = np.ones(X.shape[0], dtype=bool) + + # Handle masked arrays + if isinstance(X, np.ma.MaskedArray): + valid_mask &= ~X.mask.any(axis=1) + + # Handle non-finite values if required + if getattr(obj, "mask_nonfinite", True): + if isinstance(X, np.ma.MaskedArray): + valid_mask &= np.isfinite(X.data).all(axis=1) + else: + valid_mask &= np.isfinite(X).all(axis=1) + + if warn_on_invalid and not np.all(valid_mask): + warnings.warn( + f"SafeInferenceWrapper is in safe_inference_mode during use of {func.__name__} and invalid data detected. " + "This mode is intended for safe inference in production, not for training and evaluation.", + UserWarning, + ) + + valid_indices = np.where(valid_mask)[0] + reduced_X = X[valid_mask] + + if y is not None: + # TODO, how can we check y in the same way as the estimator? + y = check_array( + y, + force_all_finite=False, # accept_sparse="csr", + ensure_2d=False, + dtype=None, + input_name="y", + estimator=obj, + ) + reduced_y = y[valid_mask] + else: + reduced_y = None + + result = func(obj, reduced_X, reduced_y, *args, **kwargs) + + if result is None: + return None + + if isinstance(result, np.ndarray): + if result.ndim == 1: + output = np.full(X.shape[0], fill_value) + else: + output = np.full((X.shape[0], result.shape[1]), fill_value) + output[valid_indices] = result + return output + elif isinstance(result, pd.DataFrame): + output = pd.DataFrame(index=range(X.shape[0]), columns=result.columns) + output.iloc[valid_indices] = result + return output + elif isinstance(result, pd.Series): + output = pd.Series(index=range(X.shape[0]), dtype=result.dtype) + output.iloc[valid_indices] = result + return output + else: + return result + + return wrapper + + return decorator + + +class SafeInferenceWrapper(BaseEstimator, TransformerMixin): + """ + Wrapper for sklearn estimators to ensure safe inference in production environments. + + This wrapper is designed to be applied to trained models for use in production settings. + While it can be included during model development and training, the safe inference mode + should only be enabled when deploying models for inference in production. + + Parameters: + ----------- + estimator : BaseEstimator + The trained sklearn estimator to be wrapped. + safe_inference_mode : bool, default=False + If True, enables safeguards for handling invalid data during inference. + This should only be set to True when deploying models to production. + replace_value : any, default=np.nan + The value to use for replacing invalid data points. 
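+ mask_nonfinite : bool, default=True + If True, rows containing non-finite values (NaN or inf) are also treated as invalid during inference.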
+ """ + + def __init__( + self, + estimator: BaseEstimator, + safe_inference_mode: bool = False, + replace_value=np.nan, + mask_nonfinite: bool = True, + ): + self.estimator = estimator + self.safe_inference_mode = safe_inference_mode + self.replace_value = replace_value + self.mask_nonfinite = mask_nonfinite + + @property + def n_features_in_(self): + return self.estimator.n_features_in_ + + @filter_invalid_rows(warn_on_invalid=True) + def fit(self, X, y=None, **fit_params): + return self.estimator.fit(X, y, **fit_params) + + @available_if(lambda self: hasattr(self.estimator, "predict")) + @filter_invalid_rows() + def predict(self, X, y=None): + return self.estimator.predict(X) + + @available_if(lambda self: hasattr(self.estimator, "predict_proba")) + @filter_invalid_rows() + def predict_proba(self, X, y=None): + return self.estimator.predict_proba(X) + + @available_if(lambda self: hasattr(self.estimator, "decision_function")) + @filter_invalid_rows() + def decision_function(self, X, y=None): + return self.estimator.decision_function(X) + + @available_if(lambda self: hasattr(self.estimator, "transform")) + @filter_invalid_rows() + def transform(self, X, y=None): + return self.estimator.transform(X) + + @available_if(lambda self: hasattr(self.estimator, "fit_transform")) + @filter_invalid_rows(warn_on_invalid=True) + def fit_transform(self, X, y=None, **fit_params): + return self.estimator.fit_transform(X, y, **fit_params) + + @available_if(lambda self: hasattr(self.estimator, "score")) + @filter_invalid_rows(warn_on_invalid=True) + def score(self, X, y=None): + return self.estimator.score(X, y) + + @available_if(lambda self: hasattr(self.estimator, "get_feature_names_out")) + @filter_invalid_rows(warn_on_invalid=True) + def get_feature_names_out(self, *args, **kwargs): + return self.estimator.get_feature_names_out(*args, **kwargs) diff --git a/scikit_mol/standardizer.py b/scikit_mol/standardizer.py index 5277f20..76a8c55 100644 --- a/scikit_mol/standardizer.py +++ b/scikit_mol/standardizer.py @@ -8,27 +8,35 @@ from rdkit.rdBase import BlockLogs import numpy as np -from scikit_mol.core import check_transform_input, feature_names_default_mol +from scikit_mol.core import check_transform_input, feature_names_default_mol, InvalidMol class Standardizer(BaseEstimator, TransformerMixin): - """ Input a list of rdkit mols, output the same list but standardised - """ - def __init__(self, neutralize=True, parallel=False): + """Input a list of rdkit mols, output the same list but standardised""" + + def __init__(self, neutralize=True, parallel=False, safe_inference_mode=False): self.neutralize = neutralize self.parallel = parallel - self.start_method = None #TODO implement handling of start_method + self.start_method = None # TODO implement handling of start_method + self.safe_inference_mode = safe_inference_mode def fit(self, X, y=None): - return self - - def _transform(self, X): - block = BlockLogs() # Block all RDkit logging - arr = [] - for mol in X: + return self + + def _standardize_mol(self, mol): + if not mol: + if self.safe_inference_mode: + if isinstance(mol, InvalidMol): + return mol + else: + return InvalidMol(str(self), f"Invalid input molecule: {mol}") + else: + raise ValueError(f"Invalid input molecule: {mol}") + + try: + block = BlockLogs() # Block all RDkit logging # Normalizing functional groups - # https://molvs.readthedocs.io/en/latest/guide/standardize.html - clean_mol = rdMolStandardize.Cleanup(mol) + clean_mol = rdMolStandardize.Cleanup(mol) # Get parents fragments 
parent_clean_mol = rdMolStandardize.FragmentParent(clean_mol) # Neutralise @@ -37,11 +45,16 @@ def _transform(self, X): uncharged_parent_clean_mol = uncharger.uncharge(parent_clean_mol) else: uncharged_parent_clean_mol = parent_clean_mol - # Add to final list - arr.append(uncharged_parent_clean_mol) - - del block # Release logging block to previous state - return np.array(arr).reshape(-1,1) + del block # Release logging block to previous state + return uncharged_parent_clean_mol + except Exception as e: + if self.safe_inference_mode: + return InvalidMol(str(self), f"Standardization failed: {str(e)}") + else: + raise + + def _transform(self, X): + return np.array([self._standardize_mol(mol) for mol in X]).reshape(-1, 1) @feature_names_default_mol def get_feature_names_out(self, input_features=None): @@ -53,25 +66,37 @@ def transform(self, X, y=None): return self._transform(X) elif self.parallel: - n_processes = self.parallel if self.parallel > 1 else None # Pool(processes=None) autodetects - n_chunks = n_processes*2 if n_processes is not None else multiprocessing.cpu_count()*2 #TODO, tune the number of chunks per child process - - with multiprocessing.get_context(self.start_method).Pool(processes=n_processes) as pool: + n_processes = ( + self.parallel if self.parallel > 1 else None + ) # Pool(processes=None) autodetects + n_chunks = ( + n_processes * 2 + if n_processes is not None + else multiprocessing.cpu_count() * 2 + ) # TODO, tune the number of chunks per child process + + with multiprocessing.get_context(self.start_method).Pool( + processes=n_processes + ) as pool: x_chunks = np.array_split(X, n_chunks) - #TODO check what is fastest, pickle or recreate and do this only for classes that need this - #arrays = pool.map(self._transform, x_chunks) parameters = self.get_params() - arrays = pool.map(parallel_helper, [(self.__class__.__name__, parameters, x_chunk) for x_chunk in x_chunks]) + arrays = pool.map( + parallel_helper, + [ + (self.__class__.__name__, parameters, x_chunk) + for x_chunk in x_chunks + ], + ) arr = np.concatenate(arrays) return arr - def parallel_helper(args): """Parallel_helper takes a tuple with classname, the objects parameters and the mols to process. Then instantiates the class with the parameters and processes the mol. - Intention is to be able to do this in chilcprocesses as some classes can't be pickled""" + Intention is to be able to do this in child processes as some classes can't be pickled""" classname, parameters, X_mols = args from scikit_mol import standardizer + transformer = getattr(standardizer, classname)(**parameters) return transformer._transform(X_mols) diff --git a/scikit_mol/utilities.py b/scikit_mol/utilities.py index 70eac51..13c360e 100644 --- a/scikit_mol/utilities.py +++ b/scikit_mol/utilities.py @@ -1,12 +1,19 @@ -#For a non-scikit-learn check smiles sanitizer class +# For a non-scikit-learn check smiles sanitizer class + import pandas as pd from rdkit import Chem +from sklearn.base import BaseEstimator +from sklearn.pipeline import Pipeline, FeatureUnion +from sklearn.compose import ColumnTransformer +import warnings + + class CheckSmilesSanitazion: def __init__(self, return_mol=False): self.return_mol = return_mol self.errors = pd.DataFrame() - + def sanitize(self, X_smiles_list, y=None): if y: y_out = [] @@ -27,9 +34,11 @@ def sanitize(self, X_smiles_list, y=None): y_errors.append(y_value) if X_errors: - print(f'Error in parsing {len(X_errors)} SMILES. 
Unparsable SMILES can be found in self.errors') + print( + f"Error in parsing {len(X_errors)} SMILES. Unparsable SMILES can be found in self.errors" + ) - self.errors = pd.DataFrame({'SMILES':X_errors, 'y':y_errors}) + self.errors = pd.DataFrame({"SMILES": X_errors, "y": y_errors}) return X_out, y_out, X_errors, y_errors @@ -48,8 +57,69 @@ def sanitize(self, X_smiles_list, y=None): X_errors.append(smiles) if X_errors: - print(f'Error in parsing {len(X_errors)} SMILES. Unparsable SMILES can be found in self.errors') + print( + f"Error in parsing {len(X_errors)} SMILES. Unparsable SMILES can be found in self.errors" + ) - self.errors = pd.DataFrame({'SMILES':X_errors}) + self.errors = pd.DataFrame({"SMILES": X_errors}) return X_out, X_errors + + +def set_safe_inference_mode(estimator, value): + """ + Recursively set the safe_inference_mode parameter for all compatible estimators. + + :param estimator: A scikit-learn estimator, pipeline, or custom wrapper + :param value: Boolean value to set for safe_inference_mode + """ + + def _set_safe_inference_mode_recursive(est, val): + if hasattr(est, "safe_inference_mode"): + est.safe_inference_mode = val + + # Handle Pipeline + if isinstance(est, Pipeline): + for _, step in est.steps: + _set_safe_inference_mode_recursive(step, val) + + # Handle FeatureUnion + elif isinstance(est, FeatureUnion): + for _, transformer in est.transformer_list: + _set_safe_inference_mode_recursive(transformer, val) + + # Handle ColumnTransformer + elif isinstance(est, ColumnTransformer): + for _, transformer, _ in est.transformers: + _set_safe_inference_mode_recursive(transformer, val) + + # Handle SafeInferenceWrapper + elif hasattr(est, "estimator") and isinstance(est.estimator, BaseEstimator): + _set_safe_inference_mode_recursive(est.estimator, val) + + # Handle other estimators with get_params + elif isinstance(est, BaseEstimator): + params = est.get_params(deep=False) + for param_name, param_value in params.items(): + if isinstance(param_value, BaseEstimator): + _set_safe_inference_mode_recursive(param_value, val) + + # Apply the recursive function + _set_safe_inference_mode_recursive(estimator, value) + + # Final check + params = estimator.get_params(deep=True) + mismatched_params = [ + key[: -len("__safe_inference_mode")] # slice off the suffix; str.rstrip strips a character set, not a suffix + for key, val in params.items() + if key.endswith("__safe_inference_mode") and val != value + ] + + if mismatched_params: + warnings.warn( + f"The following components have 'safe_inference_mode' set to a different value than requested: {mismatched_params}.
" + "This could be due to nested estimators that were not properly handled.", + UserWarning, + ) + + return estimator diff --git a/tests/fixtures.py b/tests/fixtures.py index 8cf0751..2b5a2e6 100644 --- a/tests/fixtures.py +++ b/tests/fixtures.py @@ -10,29 +10,44 @@ from sklearn.preprocessing import FunctionTransformer from sklearn.pipeline import make_pipeline from sklearn.compose import make_column_selector, make_column_transformer -from scikit_mol.fingerprints import MACCSKeysFingerprintTransformer, RDKitFingerprintTransformer, AtomPairFingerprintTransformer, \ - TopologicalTorsionFingerprintTransformer, MorganFingerprintTransformer, SECFingerprintTransformer, \ - MHFingerprintTransformer, AvalonFingerprintTransformer +from scikit_mol.fingerprints import ( + MACCSKeysFingerprintTransformer, + RDKitFingerprintTransformer, + AtomPairFingerprintTransformer, + TopologicalTorsionFingerprintTransformer, + MorganFingerprintTransformer, + SECFingerprintTransformer, + MHFingerprintTransformer, + AvalonFingerprintTransformer, +) from scikit_mol.descriptors import MolecularDescriptorTransformer from scikit_mol.conversions import SmilesToMolTransformer from scikit_mol.standardizer import Standardizer -from scikit_mol.core import SKLEARN_VERSION_PANDAS_OUT, DEFAULT_MOL_COLUMN_NAME +from scikit_mol.core import ( + SKLEARN_VERSION_PANDAS_OUT, + DEFAULT_MOL_COLUMN_NAME, + InvalidMol, +) -#TODO these should really go into the conftest.py, so that they are automatically imported in the tests +# TODO these should really go into the conftest.py, so that they are automatically imported in the tests _SMILES_LIST = [ - 'O=C(O)c1ccccc1', - 'O=C([O-])c1ccccc1', - 'O=C([O-])c1ccccc1.[Na+]', - 'O=C(O[Na])c1ccccc1', - 'C[N+](C)C.O=C([O-])c1ccccc1', + "O=C(O)c1ccccc1", + "O=C([O-])c1ccccc1", + "O=C([O-])c1ccccc1.[Na+]", + "O=C(O[Na])c1ccccc1", + "C[N+](C)C.O=C([O-])c1ccccc1", +] +_CANONICAL_SMILES_LIST = [ + Chem.MolToSmiles(Chem.MolFromSmiles(smiles)) for smiles in _SMILES_LIST ] -_CANONICAL_SMILES_LIST = [Chem.MolToSmiles(Chem.MolFromSmiles(smiles)) for smiles in _SMILES_LIST] + @pytest.fixture def smiles_list(): return _CANONICAL_SMILES_LIST.copy() + _CONTAINER_CREATORS = [ lambda x: x, lambda x: np.array(x), @@ -48,77 +63,122 @@ def smiles_list(): ] for name in _names_to_test: _CONTAINER_CREATORS.append(lambda x, name=name: pd.Series(x, name=name)) - _CONTAINER_CREATORS.append(lambda x, name=name: pd.DataFrame({name: x}) if name else pd.DataFrame(x)) + _CONTAINER_CREATORS.append( + lambda x, name=name: pd.DataFrame({name: x}) if name else pd.DataFrame(x) + ) + -@pytest.fixture(params=[container(_CANONICAL_SMILES_LIST) for container in _CONTAINER_CREATORS] +@pytest.fixture( + params=[container(_CANONICAL_SMILES_LIST) for container in _CONTAINER_CREATORS] ) -def smiles_container(request, ): +def smiles_container( + request, +): return request.param.copy() + @pytest.fixture -def chiral_smiles_list(): #Need to be a certain size, so the fingerprints reacts to different max_lenǵths and radii - return [Chem.MolToSmiles(Chem.MolFromSmiles(smiles)) for smiles in [ - 'N[C@@H](C)C(=O)OCCCCCCCCCCCC', - 'C1C[C@H]2CCCC[C@H]2CC1CCCCCCCCC', - 'N[C@@H](C)C(=O)Oc1ccccc1CCCCCCCCCCCCCCCCCCN[H]']] +def chiral_smiles_list(): # Need to be a certain size, so the fingerprints reacts to different max_lenǵths and radii + return [ + Chem.MolToSmiles(Chem.MolFromSmiles(smiles)) + for smiles in [ + "N[C@@H](C)C(=O)OCCCCCCCCCCCC", + "C1C[C@H]2CCCC[C@H]2CC1CCCCCCCCC", + "N[C@@H](C)C(=O)Oc1ccccc1CCCCCCCCCCCCCCCCCCN[H]", + ] + ] + 
@pytest.fixture def invalid_smiles_list(smiles_list): - smiles_list.append('Invalid') + smiles_list = smiles_list.copy() + smiles_list.append("Invalid") return smiles_list + _MOLS_LIST = [Chem.MolFromSmiles(smiles) for smiles in _SMILES_LIST] + @pytest.fixture def mols_list(): return _MOLS_LIST.copy() + @pytest.fixture(params=[container(_MOLS_LIST) for container in _CONTAINER_CREATORS]) def mols_container(request): return request.param.copy() + @pytest.fixture def chiral_mols_list(chiral_smiles_list): return [Chem.MolFromSmiles(smiles) for smiles in chiral_smiles_list] +@pytest.fixture +def mols_with_invalid_container(invalid_smiles_list): + mols = [] + for smiles in invalid_smiles_list: + mol = Chem.MolFromSmiles(smiles) + if mol is None: + mols.append(InvalidMol("TestError", f"Invalid SMILES: {smiles}")) + else: + mols.append(mol) + return mols + + @pytest.fixture def fingerprint(mols_list): - return rdMolDescriptors.GetHashedMorganFingerprint(mols_list[0],2,nBits=1000) + return rdMolDescriptors.GetHashedMorganFingerprint(mols_list[0], 2, nBits=1000) + _DIR_DATA = Path(__file__).parent / "data" _FILE_SLC6A4 = _DIR_DATA / "SLC6A4_active_excapedb_subset.csv" _FILE_SLC6A4_WITH_CDDD = _DIR_DATA / "CDDD_SLC6A4_active_excapedb_subset.csv.gz" + @pytest.fixture def SLC6A4_subset(): data = pd.read_csv(_FILE_SLC6A4) return data + @pytest.fixture def SLC6A4_subset_with_cddd(SLC6A4_subset): data = SLC6A4_subset.copy().drop_duplicates(subset="Ambit_InchiKey") cddd = pd.read_csv(_FILE_SLC6A4_WITH_CDDD, index_col="Ambit_InchiKey") - data = data.merge(cddd, left_on="Ambit_InchiKey", right_index=True, how="inner", validate="one_to_one") + data = data.merge( + cddd, + left_on="Ambit_InchiKey", + right_index=True, + how="inner", + validate="one_to_one", + ) return data -skip_pandas_output_test = pytest.mark.skipif(Version(sklearn.__version__) < SKLEARN_VERSION_PANDAS_OUT, reason=f"requires scikit-learn {SKLEARN_VERSION_PANDAS_OUT} or higher") + +skip_pandas_output_test = pytest.mark.skipif( + Version(sklearn.__version__) < SKLEARN_VERSION_PANDAS_OUT, + reason=f"requires scikit-learn {SKLEARN_VERSION_PANDAS_OUT} or higher", +) _FEATURIZER_CLASSES = [ - MACCSKeysFingerprintTransformer, - RDKitFingerprintTransformer, - AtomPairFingerprintTransformer, - TopologicalTorsionFingerprintTransformer, - MorganFingerprintTransformer, - SECFingerprintTransformer, - MHFingerprintTransformer, - AvalonFingerprintTransformer, - MolecularDescriptorTransformer, - ] + MACCSKeysFingerprintTransformer, + RDKitFingerprintTransformer, + AtomPairFingerprintTransformer, + TopologicalTorsionFingerprintTransformer, + MorganFingerprintTransformer, + SECFingerprintTransformer, + MHFingerprintTransformer, + AvalonFingerprintTransformer, + MolecularDescriptorTransformer, +] + + @pytest.fixture(params=_FEATURIZER_CLASSES) def featurizer(request): return request.param() + @pytest.fixture def combined_transformer(featurizer): descriptors_pipeline = make_pipeline( @@ -136,4 +196,4 @@ def combined_transformer(featurizer): (identity_pipeline, make_column_selector(pattern=r"^cddd_\d+$")), remainder="drop", ) - return transformer \ No newline at end of file + return transformer diff --git a/tests/test_desctransformer.py b/tests/test_desctransformer.py index 959f9fc..6877def 100644 --- a/tests/test_desctransformer.py +++ b/tests/test_desctransformer.py @@ -1,14 +1,23 @@ import time -import pytest +import pytest import numpy as np import pandas as pd +import numpy.ma as ma from rdkit.Chem import Descriptors import sklearn from packaging.version 
import Version from scikit_mol.conversions import SmilesToMolTransformer from scikit_mol.descriptors import MolecularDescriptorTransformer from scikit_mol.core import SKLEARN_VERSION_PANDAS_OUT -from fixtures import mols_list, smiles_list, mols_container, smiles_container, skip_pandas_output_test +from fixtures import ( + mols_list, + smiles_list, + invalid_smiles_list, + mols_container, + smiles_container, + skip_pandas_output_test, + mols_with_invalid_container, +) from sklearn import clone from sklearn.pipeline import Pipeline import joblib @@ -18,79 +27,89 @@ def default_descriptor_transformer(): return MolecularDescriptorTransformer() + @pytest.fixture def selected_descriptor_transformer(): - return MolecularDescriptorTransformer(desc_list=['HeavyAtomCount', 'FractionCSP3', 'RingCount', 'MolLogP', 'MolWt']) + return MolecularDescriptorTransformer( + desc_list=["HeavyAtomCount", "FractionCSP3", "RingCount", "MolLogP", "MolWt"] + ) + -def test_descriptor_transformer_clonability( default_descriptor_transformer): - for t in [ default_descriptor_transformer]: - params = t.get_params() +def test_descriptor_transformer_clonability(default_descriptor_transformer): + for t in [default_descriptor_transformer]: + params = t.get_params() t2 = clone(t) params_2 = t2.get_params() - #Parameters of cloned transformers should be the same - assert all([ params[key] == params_2[key] for key in params.keys()]) - #Cloned transformers should not be the same object + # Parameters of cloned transformers should be the same + assert all([params[key] == params_2[key] for key in params.keys()]) + # Cloned transformers should not be the same object assert t2 != t + def test_descriptor_transformer_set_params(default_descriptor_transformer): for t in [default_descriptor_transformer]: - params = t.get_params() - #change extracted dictionary - params['desc_list'] = ['HeavyAtomCount', 'FractionCSP3'] - #change params in transformer - t.set_params(desc_list = ['HeavyAtomCount', 'FractionCSP3']) + params = t.get_params() + # change extracted dictionary + params["desc_list"] = ["HeavyAtomCount", "FractionCSP3"] + # change params in transformer + t.set_params(desc_list=["HeavyAtomCount", "FractionCSP3"]) # get parameters as dictionary and assert that it is the same params_2 = t.get_params() - assert all([ params[key] == params_2[key] for key in params.keys()]) + assert all([params[key] == params_2[key] for key in params.keys()]) assert len(default_descriptor_transformer.selected_descriptors) == 2 -def test_descriptor_transformer_available_descriptors(default_descriptor_transformer, selected_descriptor_transformer): - #Default have as many as in RDkit and all are selected - assert (len(default_descriptor_transformer.available_descriptors) == len(Descriptors._descList)) - assert (len(default_descriptor_transformer.selected_descriptors) == len(Descriptors._descList)) - #Default have as many as in RDkit but only 5 are selected - assert (len(selected_descriptor_transformer.available_descriptors) == len(Descriptors._descList)) - assert (len(selected_descriptor_transformer.selected_descriptors) == 5) - -def test_descriptor_transformer_transform(mols_container, default_descriptor_transformer): +def test_descriptor_transformer_available_descriptors( + default_descriptor_transformer, selected_descriptor_transformer +): + # Default have as many as in RDkit and all are selected + assert len(default_descriptor_transformer.available_descriptors) == len( + Descriptors._descList + ) + assert 
len(default_descriptor_transformer.selected_descriptors) == len( + Descriptors._descList + ) + # Default have as many as in RDkit but only 5 are selected + assert len(selected_descriptor_transformer.available_descriptors) == len( + Descriptors._descList + ) + assert len(selected_descriptor_transformer.selected_descriptors) == 5 + + +def test_descriptor_transformer_transform( + mols_container, default_descriptor_transformer +): features = default_descriptor_transformer.transform(mols_container) - assert(len(features) == len(mols_container)) - assert(len(features[0]) == len(Descriptors._descList)) - + assert len(features) == len(mols_container) + assert len(features[0]) == len(Descriptors._descList) + + def test_descriptor_transformer_wrong_descriptors(): with pytest.raises(AssertionError): - MolecularDescriptorTransformer(desc_list=['Color', 'Icecream content', 'ChokolateDarkness', 'Content42', 'MolWt']) - + MolecularDescriptorTransformer( + desc_list=[ + "Color", + "Icecream content", + "ChokolateDarkness", + "Content42", + "MolWt", + ] + ) def test_descriptor_transformer_parallel(mols_list, default_descriptor_transformer): default_descriptor_transformer.set_params(parallel=True) features = default_descriptor_transformer.transform(mols_list) - assert(len(features) == len(mols_list)) - assert(len(features[0]) == len(Descriptors._descList)) - #Now with Rdkit 2022.3 creating a second transformer and running it, froze the process - transformer2 = MolecularDescriptorTransformer(**default_descriptor_transformer.get_params()) + assert len(features) == len(mols_list) + assert len(features[0]) == len(Descriptors._descList) + # Now with Rdkit 2022.3 creating a second transformer and running it, froze the process + transformer2 = MolecularDescriptorTransformer( + **default_descriptor_transformer.get_params() + ) features2 = transformer2.transform(mols_list) - assert(len(features2) == len(mols_list)) - assert(len(features2[0]) == len(Descriptors._descList)) - - -@skip_pandas_output_test -def test_descriptor_transformer_pandas_output(mols_container, default_descriptor_transformer, selected_descriptor_transformer, pandas_output): - for transformer in [default_descriptor_transformer, selected_descriptor_transformer]: - features = transformer.transform(mols_container) - assert isinstance(features, pd.DataFrame) - assert features.shape[0] == len(mols_container) - assert features.columns.tolist() == transformer.selected_descriptors + assert len(features2) == len(mols_list) + assert len(features2[0]) == len(Descriptors._descList) -@skip_pandas_output_test -def test_descriptor_transformer_pandas_output_pipeline(smiles_container, default_descriptor_transformer, pandas_output): - pipeline = Pipeline([("s2m", SmilesToMolTransformer()), ("desc", default_descriptor_transformer)]) - features = pipeline.fit_transform(smiles_container) - assert isinstance(features, pd.DataFrame) - assert features.shape[0] == len(smiles_container) - assert features.columns.tolist() == default_descriptor_transformer.selected_descriptors # This test may fail on windows and mac (due to spawn rather than fork?) 
# def test_descriptor_transformer_parallel_speedup(mols_list, default_descriptor_transformer): @@ -100,7 +119,7 @@ def test_descriptor_transformer_pandas_output_pipeline(smiles_container, default # t0 = time.time() # features = default_descriptor_transformer.transform(mols_list) # t_single = time.time()-t0 - + # default_descriptor_transformer.set_params(parallel=True) # t0 = time.time() # features = default_descriptor_transformer.transform(mols_list) @@ -108,7 +127,94 @@ def test_descriptor_transformer_pandas_output_pipeline(smiles_container, default # assert(t_par < t_single/(n_phys_cpus/1.5)) # div by 1.5 as we don't assume full speedup - - +def test_transform_with_safe_inference_mode(mols_with_invalid_container): + transformer = MolecularDescriptorTransformer(safe_inference_mode=True) + descriptors = transformer.transform(mols_with_invalid_container) + + assert isinstance(descriptors, ma.MaskedArray) + assert len(descriptors) == len(mols_with_invalid_container) + + # Check that the last row (corresponding to the InvalidMol) is fully masked + assert np.all(descriptors.mask[-1]) + + # Check that other rows are not masked + assert not np.any(descriptors.mask[:-1]) + + +def test_transform_without_safe_inference_mode(mols_with_invalid_container): + transformer = MolecularDescriptorTransformer(safe_inference_mode=False) + with pytest.raises( + Exception + ): # You might want to be more specific about the exception type + transformer.transform(mols_with_invalid_container) + +def test_transform_parallel_with_safe_inference_mode(mols_with_invalid_container): + transformer = MolecularDescriptorTransformer( + safe_inference_mode=True, parallel=True + ) + descriptors = transformer.transform(mols_with_invalid_container) + + assert isinstance(descriptors, ma.MaskedArray) + assert len(descriptors) == len(mols_with_invalid_container) + + # Check that the last row (corresponding to the InvalidMol) is fully masked + assert np.all(descriptors.mask[-1]) + + # Check that other rows are not masked + assert not np.any(descriptors.mask[:-1]) + + +def test_transform_parallel_without_safe_inference_mode(mols_with_invalid_container): + transformer = MolecularDescriptorTransformer( + safe_inference_mode=False, parallel=True + ) + with pytest.raises( + Exception + ): # You might want to be more specific about the exception type + transformer.transform(mols_with_invalid_container) + + +def test_safe_inference_mode_setting(): + transformer = MolecularDescriptorTransformer() + assert not transformer.safe_inference_mode # Default should be False + + transformer.set_params(safe_inference_mode=True) + assert transformer.safe_inference_mode + + transformer.set_params(safe_inference_mode=False) + assert not transformer.safe_inference_mode + + +# TODO, if these tests are run before the others, these tests will fail, probably due to pandas output? 
+@skip_pandas_output_test +def test_descriptor_transformer_pandas_output( + mols_container, + default_descriptor_transformer, + selected_descriptor_transformer, + pandas_output, +): + for transformer in [ + default_descriptor_transformer, + selected_descriptor_transformer, + ]: + features = transformer.transform(mols_container) + assert isinstance(features, pd.DataFrame) + assert features.shape[0] == len(mols_container) + assert features.columns.tolist() == transformer.selected_descriptors + + +@skip_pandas_output_test +def test_descriptor_transformer_pandas_output_pipeline( + smiles_container, default_descriptor_transformer, pandas_output +): + pipeline = Pipeline( + [("s2m", SmilesToMolTransformer()), ("desc", default_descriptor_transformer)] + ) + features = pipeline.fit_transform(smiles_container) + assert isinstance(features, pd.DataFrame) + assert features.shape[0] == len(smiles_container) + assert ( + features.columns.tolist() == default_descriptor_transformer.selected_descriptors + ) diff --git a/tests/test_fptransformers.py b/tests/test_fptransformers.py index d149f3a..9a9c27a 100644 --- a/tests/test_fptransformers.py +++ b/tests/test_fptransformers.py @@ -4,153 +4,276 @@ import numpy as np import pandas as pd from rdkit import Chem -from fixtures import mols_list, smiles_list, mols_container, smiles_container, fingerprint, chiral_smiles_list, chiral_mols_list +from fixtures import ( + mols_list, + smiles_list, + mols_container, + smiles_container, + fingerprint, + chiral_smiles_list, + chiral_mols_list, + mols_with_invalid_container, + invalid_smiles_list, +) from sklearn import clone -from scikit_mol.fingerprints import MorganFingerprintTransformer, MACCSKeysFingerprintTransformer, RDKitFingerprintTransformer, AtomPairFingerprintTransformer, TopologicalTorsionFingerprintTransformer, SECFingerprintTransformer, MHFingerprintTransformer, AvalonFingerprintTransformer - +from scikit_mol.fingerprints import ( + MorganFingerprintTransformer, + MACCSKeysFingerprintTransformer, + RDKitFingerprintTransformer, + AtomPairFingerprintTransformer, + TopologicalTorsionFingerprintTransformer, + SECFingerprintTransformer, + MHFingerprintTransformer, + AvalonFingerprintTransformer, +) @pytest.fixture def morgan_transformer(): return MorganFingerprintTransformer() + @pytest.fixture def rdkit_transformer(): return RDKitFingerprintTransformer() + @pytest.fixture def atompair_transformer(): return AtomPairFingerprintTransformer() + @pytest.fixture def topologicaltorsion_transformer(): return TopologicalTorsionFingerprintTransformer() + @pytest.fixture def maccs_transformer(): return MACCSKeysFingerprintTransformer() + @pytest.fixture def secfp_transformer(): return SECFingerprintTransformer() + @pytest.fixture def mhfp_transformer(): return MHFingerprintTransformer() + @pytest.fixture def avalon_transformer(): return AvalonFingerprintTransformer() + def test_fpstransformer_fp2array(morgan_transformer, fingerprint): fp = morgan_transformer._fp2array(fingerprint) - #See that fp is the correct type, shape and bit count - assert(type(fp) == type(np.array([0]))) - assert(fp.shape == (1000,)) - assert(fp.sum() == 25) + # See that fp is the correct type, shape and bit count + assert type(fp) == type(np.array([0])) + assert fp.shape == (1000,) + assert fp.sum() == 25 + def test_fpstransformer_transform_mol(morgan_transformer, mols_list): fp = morgan_transformer._transform_mol(mols_list[0]) - #See that fp is the correct type, shape and bit count - assert(type(fp) == type(np.array([0]))) - assert(fp.shape == 
(2048,)) - assert(fp.sum() == 14) - -def test_clonability(maccs_transformer, morgan_transformer, rdkit_transformer, atompair_transformer, topologicaltorsion_transformer, secfp_transformer, mhfp_transformer, avalon_transformer): - for t in [maccs_transformer, morgan_transformer, rdkit_transformer, atompair_transformer, topologicaltorsion_transformer, secfp_transformer, mhfp_transformer, avalon_transformer]: - params = t.get_params() + # See that fp is the correct type, shape and bit count + assert type(fp) == type(np.array([0])) + assert fp.shape == (2048,) + assert fp.sum() == 14 + + +def test_clonability( + maccs_transformer, + morgan_transformer, + rdkit_transformer, + atompair_transformer, + topologicaltorsion_transformer, + secfp_transformer, + mhfp_transformer, + avalon_transformer, +): + for t in [ + maccs_transformer, + morgan_transformer, + rdkit_transformer, + atompair_transformer, + topologicaltorsion_transformer, + secfp_transformer, + mhfp_transformer, + avalon_transformer, + ]: + params = t.get_params() t2 = clone(t) params_2 = t2.get_params() - #Parameters of cloned transformers should be the same - assert all([ params[key] == params_2[key] for key in params.keys()]) - #Cloned transformers should not be the same object + # Parameters of cloned transformers should be the same + assert all([params[key] == params_2[key] for key in params.keys()]) + # Cloned transformers should not be the same object assert t2 != t -def test_set_params(morgan_transformer, rdkit_transformer, atompair_transformer, topologicaltorsion_transformer, secfp_transformer, mhfp_transformer, avalon_transformer): - for t in [morgan_transformer, atompair_transformer, topologicaltorsion_transformer, avalon_transformer]: - params = t.get_params() - #change extracted dictionary - params['nBits'] = 4242 - #change params in transformer - t.set_params(nBits = 4242) + +def test_set_params( + morgan_transformer, + rdkit_transformer, + atompair_transformer, + topologicaltorsion_transformer, + secfp_transformer, + mhfp_transformer, + avalon_transformer, +): + for t in [ + morgan_transformer, + atompair_transformer, + topologicaltorsion_transformer, + avalon_transformer, + ]: + params = t.get_params() + # change extracted dictionary + params["nBits"] = 4242 + # change params in transformer + t.set_params(nBits=4242) # get parameters as dictionary and assert that it is the same params_2 = t.get_params() - assert all([ params[key] == params_2[key] for key in params.keys()]) + assert all([params[key] == params_2[key] for key in params.keys()]) for t in [rdkit_transformer]: - params = t.get_params() - params['fpSize'] = 4242 - t.set_params(fpSize = 4242) + params = t.get_params() + params["fpSize"] = 4242 + t.set_params(fpSize=4242) params_2 = t.get_params() - assert all([ params[key] == params_2[key] for key in params.keys()]) + assert all([params[key] == params_2[key] for key in params.keys()]) for t in [secfp_transformer]: - params = t.get_params() - params['length'] = 4242 - t.set_params(length = 4242) + params = t.get_params() + params["length"] = 4242 + t.set_params(length=4242) params_2 = t.get_params() - assert all([ params[key] == params_2[key] for key in params.keys()]) + assert all([params[key] == params_2[key] for key in params.keys()]) for t in [mhfp_transformer]: - params = t.get_params() - params['n_permutations'] = 4242 - t.set_params(n_permutations = 4242) + params = t.get_params() + params["n_permutations"] = 4242 + t.set_params(n_permutations=4242) params_2 = t.get_params() - assert all([ params[key] == 
params_2[key] for key in params.keys()]) - -def test_transform(mols_container, morgan_transformer, rdkit_transformer, atompair_transformer, topologicaltorsion_transformer, maccs_transformer, secfp_transformer, mhfp_transformer, avalon_transformer): - #Test the different transformers - for t in [morgan_transformer, atompair_transformer, topologicaltorsion_transformer, maccs_transformer, rdkit_transformer, secfp_transformer, mhfp_transformer, avalon_transformer]: - params = t.get_params() + assert all([params[key] == params_2[key] for key in params.keys()]) + + +def test_transform( + mols_container, + morgan_transformer, + rdkit_transformer, + atompair_transformer, + topologicaltorsion_transformer, + maccs_transformer, + secfp_transformer, + mhfp_transformer, + avalon_transformer, +): + # Test the different transformers + for t in [ + morgan_transformer, + atompair_transformer, + topologicaltorsion_transformer, + maccs_transformer, + rdkit_transformer, + secfp_transformer, + mhfp_transformer, + avalon_transformer, + ]: + params = t.get_params() fps = t.transform(mols_container) - #Assert that the same length of input and output + # Assert that the same length of input and output assert len(fps) == len(mols_container) # assert that the size of the fingerprint is the expected size - if type(t) == type(maccs_transformer) or type(t) == type(secfp_transformer) or type(t) == type(mhfp_transformer): + if ( + type(t) == type(maccs_transformer) + or type(t) == type(secfp_transformer) + or type(t) == type(mhfp_transformer) + ): fpsize = t.nBits elif type(t) == type(rdkit_transformer): - fpsize = params['fpSize'] + fpsize = params["fpSize"] else: - fpsize = params['nBits'] - + fpsize = params["nBits"] + assert len(fps[0]) == fpsize -def test_transform_parallel(mols_container, morgan_transformer, rdkit_transformer, atompair_transformer, topologicaltorsion_transformer, maccs_transformer, secfp_transformer, mhfp_transformer, avalon_transformer): - #Test the different transformers - for t in [morgan_transformer, atompair_transformer, topologicaltorsion_transformer, maccs_transformer, rdkit_transformer, secfp_transformer, mhfp_transformer, avalon_transformer]: + +def test_transform_parallel( + mols_container, + morgan_transformer, + rdkit_transformer, + atompair_transformer, + topologicaltorsion_transformer, + maccs_transformer, + secfp_transformer, + mhfp_transformer, + avalon_transformer, +): + # Test the different transformers + for t in [ + morgan_transformer, + atompair_transformer, + topologicaltorsion_transformer, + maccs_transformer, + rdkit_transformer, + secfp_transformer, + mhfp_transformer, + avalon_transformer, + ]: t.set_params(parallel=True) - params = t.get_params() + params = t.get_params() fps = t.transform(mols_container) - #Assert that the same length of input and output + # Assert that the same length of input and output assert len(fps) == len(mols_container) # assert that the size of the fingerprint is the expected size - if type(t) == type(maccs_transformer) or type(t) == type(secfp_transformer) or type(t) == type(mhfp_transformer): + if ( + type(t) == type(maccs_transformer) + or type(t) == type(secfp_transformer) + or type(t) == type(mhfp_transformer) + ): fpsize = t.nBits elif type(t) == type(rdkit_transformer): - fpsize = params['fpSize'] + fpsize = params["fpSize"] else: - fpsize = params['nBits'] - + fpsize = params["nBits"] + assert len(fps[0]) == fpsize -def test_picklable(morgan_transformer, rdkit_transformer, atompair_transformer, topologicaltorsion_transformer, 
maccs_transformer, secfp_transformer, avalon_transformer): - #Test the different transformers - for t in [morgan_transformer, atompair_transformer, topologicaltorsion_transformer, maccs_transformer, rdkit_transformer, secfp_transformer, avalon_transformer]: +def test_picklable( + morgan_transformer, + rdkit_transformer, + atompair_transformer, + topologicaltorsion_transformer, + maccs_transformer, + secfp_transformer, + avalon_transformer, +): + # Test the different transformers + for t in [ + morgan_transformer, + atompair_transformer, + topologicaltorsion_transformer, + maccs_transformer, + rdkit_transformer, + secfp_transformer, + avalon_transformer, + ]: with tempfile.NamedTemporaryFile() as f: pickle.dump(t, f) f.seek(0) t2 = pickle.load(f) - assert(t.get_params() == t2.get_params()) - + assert t.get_params() == t2.get_params() + def assert_transformer_set_params(tr_class, new_params, mols_list): default_params = tr_class().get_params() for key in new_params.keys(): - tr = tr_class() params = tr.get_params() params[key] = new_params[key] @@ -164,20 +287,36 @@ def assert_transformer_set_params(tr_class, new_params, mols_list): fps_init_new_params = new_tr.transform(mols_list) # Now fp_default should not be the same as fp_reset_params - assert ~np.all([np.array_equal(fp_default, fp_reset_params) for fp_default, fp_reset_params in zip(fps_default, fps_reset_params)]), f"Assertation error, FP appears the same, although the {key} should be changed from {default_params[key]} to {params[key]}" + assert ~np.all( + [ + np.array_equal(fp_default, fp_reset_params) + for fp_default, fp_reset_params in zip(fps_default, fps_reset_params) + ] + ), f"Assertion error, FP appears the same, although the {key} should be changed from {default_params[key]} to {params[key]}" # fp_reset_params and fp_init_new_params should however be the same - assert np.all([np.array_equal(fp_init_new_params, fp_reset_params) for fp_init_new_params, fp_reset_params in zip(fps_init_new_params, fps_reset_params)]) , f"Assertation error, FP appears to be different, although the {key} should be changed back as well as initialized to {params[key]}" + assert np.all( + [ + np.array_equal(fp_init_new_params, fp_reset_params) + for fp_init_new_params, fp_reset_params in zip( + fps_init_new_params, fps_reset_params + ) + ] + ), f"Assertion error, FP appears to be different, although the {key} should be changed back as well as initialized to {params[key]}" def test_morgan_set_params(chiral_mols_list): - new_params = {'nBits': 1024, - 'radius': 1, - 'useBondTypes': False,# TODO, why doesn't this change the FP? - 'useChirality': True, - 'useCounts': True, - 'useFeatures': True} - - assert_transformer_set_params(MorganFingerprintTransformer, new_params, chiral_mols_list) + new_params = { + "nBits": 1024, + "radius": 1, + "useBondTypes": False, # TODO, why doesn't this change the FP?
+ "useChirality": True, + "useCounts": True, + "useFeatures": True, + } + + assert_transformer_set_params( + MorganFingerprintTransformer, new_params, chiral_mols_list + ) def test_atompairs_set_params(chiral_mols_list): @@ -186,71 +325,186 @@ def test_atompairs_set_params(chiral_mols_list): #'confId': -1, #'fromAtoms': 1, #'ignoreAtoms': 0, - 'includeChirality': True, - 'maxLength': 3, - 'minLength': 3, - 'nBits': 1024, - 'nBitsPerEntry': 3, + "includeChirality": True, + "maxLength": 3, + "minLength": 3, + "nBits": 1024, + "nBitsPerEntry": 3, #'use2D': True, #TODO, understand why this can't be set different - 'useCounts': True} - - assert_transformer_set_params(AtomPairFingerprintTransformer, new_params, chiral_mols_list) + "useCounts": True, + } + + assert_transformer_set_params( + AtomPairFingerprintTransformer, new_params, chiral_mols_list + ) def test_topologicaltorsion_set_params(chiral_mols_list): - new_params = {#'atomInvariants': 0, - #'fromAtoms': 0, - #'ignoreAtoms': 0, - #'includeChirality': True, #TODO, figure out why this setting seems to give same FP wheter toggled or not - 'nBits': 1024, - 'nBitsPerEntry': 3, - 'targetSize': 5, - 'useCounts': True} - - assert_transformer_set_params(TopologicalTorsionFingerprintTransformer, new_params, chiral_mols_list) + new_params = { #'atomInvariants': 0, + #'fromAtoms': 0, + #'ignoreAtoms': 0, + #'includeChirality': True, #TODO, figure out why this setting seems to give same FP wheter toggled or not + "nBits": 1024, + "nBitsPerEntry": 3, + "targetSize": 5, + "useCounts": True, + } + + assert_transformer_set_params( + TopologicalTorsionFingerprintTransformer, new_params, chiral_mols_list + ) + def test_RDKitFPTransformer(chiral_mols_list): - new_params = {#'atomInvariantsGenerator': None, - #'branchedPaths': False, - #'countBounds': 0, #TODO: What does this do? - 'countSimulation': True, - 'fpSize': 1024, - 'maxPath': 3, - 'minPath': 2, - 'numBitsPerFeature': 3, - 'useBondOrder': False, #TODO, why doesn't this change the FP? - #'useHs': False, #TODO, why doesn't this change the FP? - } - assert_transformer_set_params(RDKitFingerprintTransformer, new_params, chiral_mols_list) + new_params = { #'atomInvariantsGenerator': None, + #'branchedPaths': False, + #'countBounds': 0, #TODO: What does this do? + "countSimulation": True, + "fpSize": 1024, + "maxPath": 3, + "minPath": 2, + "numBitsPerFeature": 3, + "useBondOrder": False, # TODO, why doesn't this change the FP? + #'useHs': False, #TODO, why doesn't this change the FP? 
+    }
+    assert_transformer_set_params(
+        RDKitFingerprintTransformer, new_params, chiral_mols_list
+    )


 def test_SECFingerprintTransformer(chiral_mols_list):
-    new_params = {'isomeric': True,
-                  'kekulize': True,
-                  'length': 1048,
-                  'min_radius': 2,
-                  #'n_permutations': 2, # The SECFp is not using this setting
-                  'radius': 2,
-                  'rings': False,
-                  #'seed': 1 # The SECFp is not using this setting
-                  }
-    assert_transformer_set_params(SECFingerprintTransformer, new_params, chiral_mols_list)
+    new_params = {
+        "isomeric": True,
+        "kekulize": True,
+        "length": 1048,
+        "min_radius": 2,
+        #'n_permutations': 2, # The SECFp is not using this setting
+        "radius": 2,
+        "rings": False,
+        #'seed': 1 # The SECFp is not using this setting
+    }
+    assert_transformer_set_params(
+        SECFingerprintTransformer, new_params, chiral_mols_list
+    )
+

 def test_MHFingerprintTransformer(chiral_mols_list):
-    new_params = {'radius': 2,
-                  'rings': False,
-                  'isomeric': True,
-                  'kekulize': True,
-                  'min_radius': 2,
-                  'n_permutations': 4096,
-                  'seed': 44
-                  }
-    assert_transformer_set_params(MHFingerprintTransformer, new_params, chiral_mols_list)
+    new_params = {
+        "radius": 2,
+        "rings": False,
+        "isomeric": True,
+        "kekulize": True,
+        "min_radius": 2,
+        "n_permutations": 4096,
+        "seed": 44,
+    }
+    assert_transformer_set_params(
+        MHFingerprintTransformer, new_params, chiral_mols_list
+    )
+

 def test_AvalonFingerprintTransformer(chiral_mols_list):
-    new_params = {'nBits': 1024,
-                  'isQuery': True,
-                  # 'resetVect': True, #TODO: this doesn't change the FP
-                  'bitFlags': 32767
-                  }
-    assert_transformer_set_params(AvalonFingerprintTransformer, new_params, chiral_mols_list)
+    new_params = {
+        "nBits": 1024,
+        "isQuery": True,
+        # 'resetVect': True, #TODO: this doesn't change the FP
+        "bitFlags": 32767,
+    }
+    assert_transformer_set_params(
+        AvalonFingerprintTransformer, new_params, chiral_mols_list
+    )
+
+
+def test_transform_with_safe_inference_mode(
+    mols_with_invalid_container,
+    morgan_transformer,
+    rdkit_transformer,
+    atompair_transformer,
+    topologicaltorsion_transformer,
+    maccs_transformer,
+    secfp_transformer,
+    avalon_transformer,
+):
+    for t in [
+        morgan_transformer,
+        atompair_transformer,
+        topologicaltorsion_transformer,
+        maccs_transformer,
+        rdkit_transformer,
+        secfp_transformer,
+        avalon_transformer,
+    ]:
+        t.set_params(safe_inference_mode=True)
+        print(type(t))
+        fps = t.transform(mols_with_invalid_container)
+
+        assert len(fps) == len(mols_with_invalid_container)
+
+        # Check that the last row (corresponding to the InvalidMol) is fully masked
+        assert np.all(fps.mask[-1])
+
+        # Check that no other rows have masked entries
+        assert not np.any(fps.mask[:-1])
+
+
+def test_transform_without_safe_inference_mode(
+    mols_with_invalid_container,
+    morgan_transformer,
+    rdkit_transformer,
+    atompair_transformer,
+    topologicaltorsion_transformer,
+    maccs_transformer,
+    secfp_transformer,
+    avalon_transformer,
+    # MHFP seems to accept invalid mols and return 0,0,0,0's
+):
+    for t in [
+        morgan_transformer,
+        atompair_transformer,
+        topologicaltorsion_transformer,
+        maccs_transformer,
+        rdkit_transformer,
+        secfp_transformer,
+        avalon_transformer,
+    ]:
+        t.set_params(safe_inference_mode=False)
+        with pytest.raises(
+            Exception
+        ):  # You might want to be more specific about the exception type
+            print(f"testing {type(t)}")
+            t.transform(mols_with_invalid_container)
+
+
+# Check parallel processing combined with safe inference error handling
+def test_transform_parallel_with_safe_inference_mode(
+    mols_with_invalid_container,
+    morgan_transformer,
+
rdkit_transformer, + atompair_transformer, + topologicaltorsion_transformer, + maccs_transformer, + secfp_transformer, + avalon_transformer, +): + for t in [ + morgan_transformer, + atompair_transformer, + topologicaltorsion_transformer, + maccs_transformer, + rdkit_transformer, + secfp_transformer, + avalon_transformer, + ]: + t.set_params(safe_inference_mode=True, parallel=True) + fps = t.transform(mols_with_invalid_container) + + assert len(fps) == len(mols_with_invalid_container) + + print(fps.mask) + # Check that the last row (corresponding to the InvalidMol) is masked + assert np.all( + fps.mask[-1] + ) # Mask should be true for all elements in the last row + + # Check that other rows don't contain any masked values + assert not np.any(fps.mask[:-1, :]) diff --git a/tests/test_safeinferencemode.py b/tests/test_safeinferencemode.py new file mode 100644 index 0000000..921cc0f --- /dev/null +++ b/tests/test_safeinferencemode.py @@ -0,0 +1,115 @@ +import pytest +import numpy as np +import pandas as pd +from sklearn.pipeline import Pipeline +from sklearn.ensemble import RandomForestRegressor +from scikit_mol.conversions import SmilesToMolTransformer +from scikit_mol.fingerprints import MorganFingerprintTransformer +from scikit_mol.safeinference import SafeInferenceWrapper +from scikit_mol.utilities import set_safe_inference_mode + +from fixtures import ( + SLC6A4_subset, + invalid_smiles_list, + skip_pandas_output_test, + smiles_list, +) + + +@pytest.fixture +def smiles_pipeline(): + return Pipeline( + [ + ("s2m", SmilesToMolTransformer()), + ("FP", MorganFingerprintTransformer()), + ( + "RF", + SafeInferenceWrapper( + RandomForestRegressor(n_estimators=3, random_state=42) + ), + ), + ] + ) + + +def test_safeinference_wrapper_basic(smiles_pipeline, SLC6A4_subset): + X_smiles, Y = SLC6A4_subset.SMILES, SLC6A4_subset.pXC50 + X_smiles = X_smiles.to_frame() + + # Set safe inference mode + set_safe_inference_mode(smiles_pipeline, True) + + # Train the model + smiles_pipeline.fit(X_smiles, Y) + + # Test prediction + predictions = smiles_pipeline.predict(X_smiles) + assert len(predictions) == len(X_smiles) + assert not np.any(np.isnan(predictions)) + + +def test_safeinference_wrapper_with_invalid_smiles( + smiles_pipeline, SLC6A4_subset, invalid_smiles_list +): + X_smiles, Y = SLC6A4_subset.SMILES[:100], SLC6A4_subset.pXC50[:100] + X_smiles = X_smiles.to_frame() + + # Set safe inference mode + set_safe_inference_mode(smiles_pipeline, True) + + # Train the model + smiles_pipeline.fit(X_smiles, Y) + + # Create a test set with invalid SMILES + X_test = pd.DataFrame({"SMILES": X_smiles["SMILES"].tolist() + invalid_smiles_list}) + + # Test prediction with invalid SMILES + predictions = smiles_pipeline.predict(X_test) + assert len(predictions) == len(X_test) + assert np.any(np.isnan(predictions)) + assert np.all(np.isnan(predictions[-1])) # Only last should be nan + assert np.all(~np.isnan(predictions[:-1])) # All others should not be nan + + +def test_safeinference_wrapper_without_safe_mode( + smiles_pipeline, SLC6A4_subset, invalid_smiles_list +): + X_smiles, Y = SLC6A4_subset.SMILES[:100], SLC6A4_subset.pXC50[:100] + X_smiles = X_smiles.to_frame() + + # Ensure safe inference mode is off (default behavior) + set_safe_inference_mode(smiles_pipeline, False) + + # Train the model + smiles_pipeline.fit(X_smiles, Y) + + # Create a test set with invalid SMILES + X_test = pd.DataFrame({"SMILES": X_smiles["SMILES"].tolist() + invalid_smiles_list}) + + # Test prediction with invalid SMILES + with 
pytest.raises(Exception): + smiles_pipeline.predict(X_test) + + +@skip_pandas_output_test +def test_safeinference_wrapper_pandas_output( + smiles_pipeline, SLC6A4_subset, pandas_output +): + X_smiles = SLC6A4_subset.SMILES[:100].to_frame() + + # Set safe inference mode + set_safe_inference_mode(smiles_pipeline, True) + + # Fit and transform (up to the FP step) + result = smiles_pipeline[:-1].fit_transform(X_smiles) + assert isinstance(result, pd.DataFrame) + assert result.shape[0] == len(X_smiles) + assert result.shape[1] == smiles_pipeline.named_steps["FP"].nBits + + +@skip_pandas_output_test +def test_safeinference_wrapper_get_feature_names_out(smiles_pipeline): + # Get feature names from the FP step + feature_names = smiles_pipeline.named_steps["FP"].get_feature_names_out() + assert len(feature_names) == smiles_pipeline.named_steps["FP"].nBits + assert all(isinstance(name, str) for name in feature_names) diff --git a/tests/test_sanitizer.py b/tests/test_sanitizer.py index def82a0..f7193c2 100644 --- a/tests/test_sanitizer.py +++ b/tests/test_sanitizer.py @@ -5,43 +5,62 @@ from fixtures import smiles_list, invalid_smiles_list from scikit_mol.utilities import CheckSmilesSanitazion + @pytest.fixture def sanitizer(): return CheckSmilesSanitazion() + @pytest.fixture def return_mol_sanitizer(): return CheckSmilesSanitazion(return_mol=True) + def test_checksmilessanitation(smiles_list, invalid_smiles_list, sanitizer): smiles_list_sanitized, errors = sanitizer.sanitize(invalid_smiles_list) assert len(invalid_smiles_list) > len(smiles_list_sanitized) - assert all([ a == b for a, b in zip(smiles_list, smiles_list_sanitized)]) + assert all([a == b for a, b in zip(smiles_list, smiles_list_sanitized)]) assert errors[0] == sanitizer.errors.SMILES[0] + def test_checksmilessanitation_x_and_y(smiles_list, invalid_smiles_list, sanitizer): - smiles_list_sanitized, y_sanitized, errors, y_errors = sanitizer.sanitize(smiles_list, list(range(len(smiles_list)))) + smiles_list_sanitized, y_sanitized, errors, y_errors = sanitizer.sanitize( + invalid_smiles_list, list(range(len(invalid_smiles_list))) + ) assert len(invalid_smiles_list) > len(smiles_list_sanitized) - assert all([ a == b for a, b in zip(smiles_list, smiles_list_sanitized)]) + assert all([a == b for a, b in zip(smiles_list, smiles_list_sanitized)]) assert errors[0] == sanitizer.errors.SMILES[0] - #Test that y is correctly split into y_error and the rest - assert all([ a == b for a, b in zip(y_sanitized, list(range(len(smiles_list) -1 )))]) - assert y_errors[0] == len(smiles_list)-1 #Last smiles is invalid + # Test that y is correctly split into y_error and the rest + assert all([a == b for a, b in zip(y_sanitized, list(range(len(smiles_list) - 1)))]) + assert y_errors[0] == len(invalid_smiles_list) - 1 # Last smiles is invalid + def test_checksmilessanitation_np(smiles_list, invalid_smiles_list, sanitizer): smiles_list_sanitized, errors = sanitizer.sanitize(np.array(invalid_smiles_list)) assert len(invalid_smiles_list) > len(smiles_list_sanitized) - assert all([ a == b for a, b in zip(smiles_list, smiles_list_sanitized)]) + assert all([a == b for a, b in zip(smiles_list, smiles_list_sanitized)]) assert errors[0] == sanitizer.errors.SMILES[0] + def test_checksmilessanitation_numpy(smiles_list, invalid_smiles_list, sanitizer): smiles_list_sanitized, errors = sanitizer.sanitize(pd.Series(invalid_smiles_list)) assert len(invalid_smiles_list) > len(smiles_list_sanitized) - assert all([ a == b for a, b in zip(smiles_list, smiles_list_sanitized)]) + assert 
all([a == b for a, b in zip(smiles_list, smiles_list_sanitized)]) assert errors[0] == sanitizer.errors.SMILES[0] -def test_checksmilessanitation_return_mol(smiles_list, invalid_smiles_list, return_mol_sanitizer): + +def test_checksmilessanitation_return_mol( + smiles_list, invalid_smiles_list, return_mol_sanitizer +): smiles_list_sanitized, errors = return_mol_sanitizer.sanitize(invalid_smiles_list) assert len(invalid_smiles_list) > len(smiles_list_sanitized) - assert all([ a == b for a, b in zip(smiles_list, [Chem.MolToSmiles(smiles) for smiles in smiles_list_sanitized])]) - assert errors[0] == return_mol_sanitizer.errors.SMILES[0] \ No newline at end of file + assert all( + [ + a == b + for a, b in zip( + smiles_list, + [Chem.MolToSmiles(smiles) for smiles in smiles_list_sanitized], + ) + ] + ) + assert errors[0] == return_mol_sanitizer.errors.SMILES[0] diff --git a/tests/test_smilestomol.py b/tests/test_smilestomol.py index e01af52..19bf288 100644 --- a/tests/test_smilestomol.py +++ b/tests/test_smilestomol.py @@ -6,33 +6,57 @@ from rdkit import Chem import sklearn from scikit_mol.conversions import SmilesToMolTransformer -from scikit_mol.core import SKLEARN_VERSION_PANDAS_OUT, DEFAULT_MOL_COLUMN_NAME -from fixtures import smiles_list, invalid_smiles_list, smiles_container, skip_pandas_output_test +from scikit_mol.core import ( + SKLEARN_VERSION_PANDAS_OUT, + DEFAULT_MOL_COLUMN_NAME, + InvalidMol, +) +from fixtures import ( + smiles_list, + invalid_smiles_list, + smiles_container, + skip_pandas_output_test, +) @pytest.fixture def smilestomol_transformer(): return SmilesToMolTransformer() + def test_smilestomol(smiles_container, smilestomol_transformer): - result_mols = smilestomol_transformer.transform(smiles_container) - result_smiles = [Chem.MolToSmiles(mol) for mol in result_mols.flatten()] - if isinstance(smiles_container, pd.DataFrame): - expected_smiles = smiles_container.iloc[:, 0].tolist() - else: - expected_smiles = smiles_container - assert all([ a == b for a, b in zip(expected_smiles, result_smiles)]) + result_mols = smilestomol_transformer.transform(smiles_container) + result_smiles = [Chem.MolToSmiles(mol) for mol in result_mols.flatten()] + if isinstance(smiles_container, pd.DataFrame): + expected_smiles = smiles_container.iloc[:, 0].tolist() + else: + expected_smiles = smiles_container + assert all([a == b for a, b in zip(expected_smiles, result_smiles)]) + + +def test_smilestomol_transform(smilestomol_transformer, smiles_container): + result = smilestomol_transformer.transform(smiles_container) + assert len(result) == len(smiles_container) + assert all(isinstance(mol, Chem.Mol) for mol in result.flatten()) + + +def test_smilestomol_fit(smilestomol_transformer, smiles_container): + result = smilestomol_transformer.fit(smiles_container) + assert result == smilestomol_transformer + def test_smilestomol_clone(smilestomol_transformer): t2 = clone(smilestomol_transformer) - params = smilestomol_transformer.get_params() + params = smilestomol_transformer.get_params() params_2 = t2.get_params() - assert all([ params[key] == params_2[key] for key in params.keys()]) + assert all([params[key] == params_2[key] for key in params.keys()]) + def test_smilestomol_unsanitzable(invalid_smiles_list, smilestomol_transformer): with pytest.raises(ValueError): smilestomol_transformer.transform(invalid_smiles_list) + def test_descriptor_transformer_parallel(smiles_container, smilestomol_transformer): smilestomol_transformer.set_params(parallel=True) mol_list = 
smilestomol_transformer.transform(smiles_container)
@@ -40,11 +64,104 @@ def test_descriptor_transformer_parallel(smiles_container, smilestomol_transform
         expected_smiles = smiles_container.iloc[:, 0].tolist()
     else:
         expected_smiles = smiles_container
-    assert all([ a == b for a, b in zip(expected_smiles, [Chem.MolToSmiles(mol) for mol in mol_list.flatten()])])
+    assert all(
+        [
+            a == b
+            for a, b in zip(
+                expected_smiles, [Chem.MolToSmiles(mol) for mol in mol_list.flatten()]
+            )
+        ]
+    )
+
+
+def test_smilestomol_inverse_transform(smilestomol_transformer, smiles_container):
+    mols = smilestomol_transformer.transform(smiles_container)
+    result = smilestomol_transformer.inverse_transform(mols)
+    assert len(result) == len(smiles_container)
+    assert all(isinstance(smiles, str) for smiles in result.flatten())
+
+
+def test_smilestomol_inverse_transform_with_invalid(
+    invalid_smiles_list, smilestomol_transformer
+):
+    smilestomol_transformer.set_params(safe_inference_mode=True)
+
+    # Forward transform
+    mols = smilestomol_transformer.transform(invalid_smiles_list)
+
+    # Inverse transform
+    result = smilestomol_transformer.inverse_transform(mols)
+
+    assert len(result) == len(invalid_smiles_list)
+
+    # Check that all but the last element are the same as the original SMILES
+    for original, res in zip(invalid_smiles_list[:-1], result[:-1].flatten()):
+        assert isinstance(res, str)
+        assert original == res
+
+    # Check that the last element is an InvalidMol instance
+    assert isinstance(result[-1].item(), InvalidMol)
+    assert "Invalid SMILES" in result[-1].item().error
+    assert invalid_smiles_list[-1] in result[-1].item().error
+
+
+def test_smilestomol_get_feature_names_out(smilestomol_transformer):
+    feature_names = smilestomol_transformer.get_feature_names_out()
+    assert feature_names == [DEFAULT_MOL_COLUMN_NAME]
+
+
+def test_smilestomol_safe_inference(invalid_smiles_list, smilestomol_transformer):
+    smilestomol_transformer.set_params(safe_inference_mode=True)
+    result = smilestomol_transformer.transform(invalid_smiles_list)
+
+    assert len(result) == len(invalid_smiles_list)
+    assert isinstance(result, np.ndarray)
+
+    # Check that all but the last element are valid RDKit Mol objects
+    for mol in result[:-1].flatten():
+        assert isinstance(mol, Chem.Mol)
+        assert mol is not None
+
+    # Check that the last element is an InvalidMol instance
+    last_mol = result[-1].item()
+    assert isinstance(last_mol, InvalidMol)
+
+    # Check if the error message is correctly set for the invalid SMILES
+    assert "Invalid SMILES" in last_mol.error
+    assert invalid_smiles_list[-1] in last_mol.error
+
+
+@skip_pandas_output_test
+def test_smilestomol_safe_inference_pandas_output(
+    invalid_smiles_list, smilestomol_transformer, pandas_output
+):
+    smilestomol_transformer.set_params(safe_inference_mode=True)
+    result = smilestomol_transformer.transform(invalid_smiles_list)
+
+    assert len(result) == len(invalid_smiles_list)
+    assert isinstance(result, pd.DataFrame)
+    assert result.columns.tolist() == [DEFAULT_MOL_COLUMN_NAME]
+
+    # Check that all but the last element are valid RDKit Mol objects
+    for mol in result[DEFAULT_MOL_COLUMN_NAME][:-1]:
+        assert isinstance(mol, Chem.Mol)
+        assert mol is not None
+
+    # Check that the last element is an InvalidMol instance
+    last_mol = result[DEFAULT_MOL_COLUMN_NAME].iloc[-1]
+    assert isinstance(last_mol, InvalidMol)
+
+    # Check if the error message is correctly set for the invalid SMILES
+    assert
"Invalid SMILES" in last_mol.error + assert invalid_smiles_list[-1] in last_mol.error + @skip_pandas_output_test def test_pandas_output(smiles_container, smilestomol_transformer, pandas_output): - mols = smilestomol_transformer.transform(smiles_container) - assert isinstance(mols, pd.DataFrame) - assert mols.shape[0] == len(smiles_container) - assert mols.columns.tolist() == [DEFAULT_MOL_COLUMN_NAME] \ No newline at end of file + mols = smilestomol_transformer.transform(smiles_container) + assert isinstance(mols, pd.DataFrame) + assert mols.shape[0] == len(smiles_container) + assert mols.columns.tolist() == [DEFAULT_MOL_COLUMN_NAME]