diff --git a/scikit_mol/fingerprints/baseclasses.py b/scikit_mol/fingerprints/baseclasses.py index ce28e18..03ca11b 100644 --- a/scikit_mol/fingerprints/baseclasses.py +++ b/scikit_mol/fingerprints/baseclasses.py @@ -36,18 +36,16 @@ ) -class FpsTransformer(ABC, BaseEstimator, TransformerMixin): +class BaseFpsTransformer(ABC, BaseEstimator, TransformerMixin): def __init__( self, parallel: Union[bool, int] = False, start_method: str = None, safe_inference_mode: bool = False, - dtype: np.dtype = np.int8, ): self.parallel = parallel self.start_method = start_method self.safe_inference_mode = safe_inference_mode - self.dtype = dtype @property def nBits(self): @@ -98,34 +96,25 @@ def get_feature_names_out(self, input_features=None): prefix = self._get_column_prefix() return np.array([f"{prefix}_{i}" for i in range(1, self.fpSize + 1)]) - @abstractmethod - def _mol2fp(self, mol): - """Generate fingerprint from mol - - MUST BE OVERWRITTEN - """ - raise NotImplementedError("_mol2fp not implemented") - - def _fp2array(self, fp): - if fp: - arr = np.zeros((self.fpSize,), dtype=self.dtype) - DataStructs.ConvertToNumpyArray(fp, arr) - return arr - else: - return np.ma.masked_all((self.fpSize,), dtype=self.dtype) - - def _transform_mol(self, mol): + def _safe_transform_mol(self, mol): + """Handle safe inference mode with masked arrays""" if not mol and self.safe_inference_mode: - return self._fp2array(False) + return np.ma.masked_all(self.fpSize) + try: - fp = self._mol2fp(mol) - return self._fp2array(fp) + result = self._transform_mol(mol) + return result except Exception as e: if self.safe_inference_mode: - return self._fp2array(False) + return np.ma.masked_all(self.fpSize) else: raise e + @abstractmethod + def _transform_mol(self, mol): + """Transform a single molecule to numpy array""" + raise NotImplementedError + def fit(self, X, y=None): """Included for scikit-learn compatibility @@ -137,15 +126,20 @@ def fit(self, X, y=None): def _transform(self, X): if self.safe_inference_mode: # Use the new method with masked arrays if we're in safe inference mode - arrays = [self._transform_mol(mol) for mol in X] + arrays = [self._safe_transform_mol(mol) for mol in X] print(arrays) return np.ma.stack(arrays) - else: + elif hasattr( + self, "dtype" + ): # TODO, it seems a bit of a code smell that we have to preemptively test a property from the baseclass? # Use the original, faster method if we're not in safe inference mode arr = np.zeros((len(X), self.fpSize), dtype=self.dtype) for i, mol in enumerate(X): arr[i, :] = self._transform_mol(mol) return arr + else: # We are unsure on the dtype, so we don't use a preassigned array #TODO test time differnece to previous + arrays = [self._transform_mol(mol) for mol in X] + return np.stack(arrays) def _transform_sparse(self, X): arr = np.zeros((len(X), self.fpSize), dtype=self.dtype) @@ -202,20 +196,49 @@ def transform(self, X, y=None): return arr -class FpsGeneratorTransformer(FpsTransformer): - _regenerate_on_properties = () +class FpsTransformer(BaseFpsTransformer): + """Classic fingerprint transformer using mol2fp pattern""" - def _fp2array(self, fp): - raise DeprecationWarning("Generators can directly return fingerprints") + def __init__( + self, + parallel: Union[bool, int] = False, + safe_inference_mode: bool = False, + dtype: np.dtype = np.int8, + ): + super().__init__(parallel=parallel, safe_inference_mode=safe_inference_mode) + self.dtype = dtype + def _transform_mol(self, mol): + """Implements the mol -> rdkit fingerprint data structure -> numpy array pattern""" + fp = self._mol2fp(mol) + return self._fp2array(fp) + + @abstractmethod def _mol2fp(self, mol): - raise DeprecationWarning("use _mol2array") + """Generate fingerprint from mol + + MUST BE OVERWRITTEN + """ + raise NotImplementedError("_mol2fp not implemented") + + def _fp2array(self, fp): + """Convert RDKit fingerprint data structure to numpy array""" + if fp: + arr = np.zeros((self.fpSize,), dtype=self.dtype) + DataStructs.ConvertToNumpyArray(fp, arr) + return arr + else: + return np.ma.masked_all((self.fpSize,), dtype=self.dtype) + + +class FpsGeneratorTransformer(BaseFpsTransformer): + _regenerate_on_properties = () def __getstate__(self): # Get the state of the parent class state = super().__getstate__() state.update(self.get_params()) - # Remove the unpicklable property from the state + # Remove the potentiallyunpicklable property from the state state.pop("_fpgen", None) # fpgen is not picklable return state @@ -234,6 +257,8 @@ def __setstate__(self, state): ] self._generate_fp_generator() + # TODO: overload set_params in order to not make multiple calls to _generate_fp_generator + def __setattr__(self, name: str, value): super().__setattr__(name, value) if ( diff --git a/tests/test_fptransformersgenerator.py b/tests/test_fptransformersgenerator.py index 81da19c..c61f6e9 100644 --- a/tests/test_fptransformersgenerator.py +++ b/tests/test_fptransformersgenerator.py @@ -2,25 +2,38 @@ import tempfile import pytest import numpy as np -from fixtures import mols_list, smiles_list, mols_container, smiles_container, fingerprint, chiral_smiles_list, chiral_mols_list +from fixtures import ( + mols_list, + smiles_list, + mols_container, + smiles_container, + fingerprint, + chiral_smiles_list, + chiral_mols_list, +) from sklearn import clone -from scikit_mol.fingerprints import (MorganFPGeneratorTransformer, - RDKitFPGeneratorTransformer, - AtomPairFPGeneratorTransformer, - TopologicalTorsionFPGeneatorTransformer, - ) +from scikit_mol.fingerprints import ( + MorganFPGeneratorTransformer, + RDKitFPGeneratorTransformer, + AtomPairFPGeneratorTransformer, + TopologicalTorsionFPGeneatorTransformer, +) -test_transformers = [MorganFPGeneratorTransformer, RDKitFPGeneratorTransformer, - AtomPairFPGeneratorTransformer, TopologicalTorsionFPGeneatorTransformer] +test_transformers = [ + MorganFPGeneratorTransformer, + RDKitFPGeneratorTransformer, + AtomPairFPGeneratorTransformer, + TopologicalTorsionFPGeneatorTransformer, +] -@pytest.mark.parametrize("transformer_class", test_transformers) -def test_fpstransformer_fp2array(transformer_class, fingerprint): - transformer = transformer_class() +# @pytest.mark.parametrize("transformer_class", test_transformers) +# def test_fpstransformer_fp2array(transformer_class, fingerprint): +# transformer = transformer_class() - with pytest.raises(DeprecationWarning, match='Generators can directly return fingerprints'): - fp = transformer._fp2array(fingerprint) +# with pytest.raises(DeprecationWarning, match='Generators can directly return fingerprints'): +# fp = transformer._fp2array(fingerprint) @pytest.mark.parametrize("transformer_class", test_transformers) @@ -28,75 +41,79 @@ def test_fpstransformer_transform_mol(transformer_class, mols_list): transformer = transformer_class() fp = transformer._transform_mol(mols_list[0]) - #See that fp is the correct type, shape and bit count - assert(type(fp) == type(np.array([0]))) - assert(fp.shape == (2048,)) + # See that fp is the correct type, shape and bit count + assert type(fp) == type(np.array([0])) + assert fp.shape == (2048,) if isinstance(transformer, RDKitFPGeneratorTransformer): - assert(fp.sum() == 104) + assert fp.sum() == 104 elif isinstance(transformer, AtomPairFPGeneratorTransformer): - assert (fp.sum() == 32) + assert fp.sum() == 32 elif isinstance(transformer, TopologicalTorsionFPGeneatorTransformer): - assert (fp.sum() == 12) + assert fp.sum() == 12 elif isinstance(transformer, MorganFPGeneratorTransformer): - assert (fp.sum() == 14) + assert fp.sum() == 14 else: raise NotImplementedError("missing Assert") + @pytest.mark.parametrize("transformer_class", test_transformers) def test_clonability(transformer_class): transformer = transformer_class() - params = transformer.get_params() + params = transformer.get_params() t2 = clone(transformer) params_2 = t2.get_params() - #Parameters of cloned transformers should be the same - assert all([ params[key] == params_2[key] for key in params.keys()]) - #Cloned transformers should not be the same object + # Parameters of cloned transformers should be the same + assert all([params[key] == params_2[key] for key in params.keys()]) + # Cloned transformers should not be the same object assert t2 != transformer + @pytest.mark.parametrize("transformer_class", test_transformers) def test_set_params(transformer_class): transformer = transformer_class() - params = transformer.get_params() - #change extracted dictionary - params['fpSize'] = 4242 - #change params in transformer - transformer.set_params(fpSize = 4242) + params = transformer.get_params() + # change extracted dictionary + params["fpSize"] = 4242 + # change params in transformer + transformer.set_params(fpSize=4242) # get parameters as dictionary and assert that it is the same params_2 = transformer.get_params() - assert all([ params[key] == params_2[key] for key in params.keys()]) + assert all([params[key] == params_2[key] for key in params.keys()]) + @pytest.mark.parametrize("transformer_class", test_transformers) def test_transform(mols_container, transformer_class): transformer = transformer_class() - #Test the different transformers - params = transformer.get_params() + # Test the different transformers + params = transformer.get_params() fps = transformer.transform(mols_container) - #Assert that the same length of input and output + # Assert that the same length of input and output assert len(fps) == len(mols_container) - fpsize = params['fpSize'] + fpsize = params["fpSize"] assert len(fps[0]) == fpsize + @pytest.mark.parametrize("transformer_class", test_transformers) def test_transform_parallel(mols_container, transformer_class): transformer = transformer_class() - #Test the different transformers + # Test the different transformers transformer.set_params(parallel=True) - params = transformer.get_params() + params = transformer.get_params() fps = transformer.transform(mols_container) - #Assert that the same length of input and output + # Assert that the same length of input and output assert len(fps) == len(mols_container) - fpsize = params['fpSize'] + fpsize = params["fpSize"] assert len(fps[0]) == fpsize @pytest.mark.parametrize("transformer_class", test_transformers) def test_picklable(transformer_class): - #Test the different transformers + # Test the different transformers transformer = transformer_class() p = transformer.get_params() @@ -107,8 +124,8 @@ def test_picklable(transformer_class): print(p) print(vars(transformer)) print(vars(t2)) - assert(transformer.get_params() == t2.get_params()) - + assert transformer.get_params() == t2.get_params() + @pytest.mark.parametrize("transfomer", test_transformers) def assert_transformer_set_params(transfomer, new_params, mols_list): @@ -128,20 +145,36 @@ def assert_transformer_set_params(transfomer, new_params, mols_list): # Now fp_default should not be the same as fp_reset_params - assert ~np.all([np.array_equal(fp_default, fp_reset_params) for fp_default, fp_reset_params in zip(fps_default, fps_reset_params)]), f"Assertation error, FP appears the same, although the {key} should be changed from {default_params[key]} to {params[key]}" + assert ~np.all( + [ + np.array_equal(fp_default, fp_reset_params) + for fp_default, fp_reset_params in zip(fps_default, fps_reset_params) + ] + ), f"Assertation error, FP appears the same, although the {key} should be changed from {default_params[key]} to {params[key]}" # fp_reset_params and fp_init_new_params should however be the same - assert np.all([np.array_equal(fp_init_new_params, fp_reset_params) for fp_init_new_params, fp_reset_params in zip(fps_init_new_params, fps_reset_params)]) , f"Assertation error, FP appears to be different, although the {key} should be changed back as well as initialized to {params[key]}" + assert np.all( + [ + np.array_equal(fp_init_new_params, fp_reset_params) + for fp_init_new_params, fp_reset_params in zip( + fps_init_new_params, fps_reset_params + ) + ] + ), f"Assertation error, FP appears to be different, although the {key} should be changed back as well as initialized to {params[key]}" def test_morgan_set_params(chiral_mols_list): - new_params = {'fpSize': 1024, - 'radius': 1, - 'useBondTypes': False,# TODO, why doesn't this change the FP? - 'useChirality': True, - 'useCounts': True, - 'useFeatures': True} - - assert_transformer_set_params(MorganFPGeneratorTransformer, new_params, chiral_mols_list) + new_params = { + "fpSize": 1024, + "radius": 1, + "useBondTypes": False, # TODO, why doesn't this change the FP? + "useChirality": True, + "useCounts": True, + "useFeatures": True, + } + + assert_transformer_set_params( + MorganFPGeneratorTransformer, new_params, chiral_mols_list + ) def test_atompairs_set_params(chiral_mols_list): @@ -150,39 +183,48 @@ def test_atompairs_set_params(chiral_mols_list): #'confId': -1, #'fromAtoms': 1, #'ignoreAtoms': 0, - 'includeChirality': True, - 'maxLength': 3, - 'minLength': 3, - 'fpSize': 1024, + "includeChirality": True, + "maxLength": 3, + "minLength": 3, + "fpSize": 1024, #'nBitsPerEntry': 3, #Todo: not setable with the generators? #'use2D': True, #TODO, understand why this can't be set different - 'useCounts': True} - - assert_transformer_set_params(AtomPairFPGeneratorTransformer, new_params, chiral_mols_list) + "useCounts": True, + } + + assert_transformer_set_params( + AtomPairFPGeneratorTransformer, new_params, chiral_mols_list + ) def test_topologicaltorsion_set_params(chiral_mols_list): - new_params = {#'atomInvariants': 0, - #'fromAtoms': 0, - #'ignoreAtoms': 0, - #'includeChirality': True, #TODO, figure out why this setting seems to give same FP wheter toggled or not - 'fpSize': 1024, - #'nBitsPerEntry': 3, #Todo: not setable with the generators? - 'targetSize': 5, - 'useCounts': True} - - assert_transformer_set_params(TopologicalTorsionFPGeneatorTransformer, new_params, chiral_mols_list) + new_params = { #'atomInvariants': 0, + #'fromAtoms': 0, + #'ignoreAtoms': 0, + #'includeChirality': True, #TODO, figure out why this setting seems to give same FP wheter toggled or not + "fpSize": 1024, + #'nBitsPerEntry': 3, #Todo: not setable with the generators? + "targetSize": 5, + "useCounts": True, + } + + assert_transformer_set_params( + TopologicalTorsionFPGeneatorTransformer, new_params, chiral_mols_list + ) + def test_RDKitFPTransformer(chiral_mols_list): - new_params = {#'atomInvariantsGenerator': None, - #'branchedPaths': False, - #'countBounds': 0, #TODO: What does this do? - 'countSimulation': True, - 'fpSize': 1024, - 'maxPath': 3, - 'minPath': 2, - 'numBitsPerFeature': 3, - 'useBondOrder': False, #TODO, why doesn't this change the FP? - #'useHs': False, #TODO, why doesn't this change the FP? - } - assert_transformer_set_params(RDKitFPGeneratorTransformer, new_params, chiral_mols_list) + new_params = { #'atomInvariantsGenerator': None, + #'branchedPaths': False, + #'countBounds': 0, #TODO: What does this do? + "countSimulation": True, + "fpSize": 1024, + "maxPath": 3, + "minPath": 2, + "numBitsPerFeature": 3, + "useBondOrder": False, # TODO, why doesn't this change the FP? + #'useHs': False, #TODO, why doesn't this change the FP? + } + assert_transformer_set_params( + RDKitFPGeneratorTransformer, new_params, chiral_mols_list + )