From ff8cf2eea6199d60ec4bee4e893d114d8d7add8c Mon Sep 17 00:00:00 2001 From: Esben Jannik Bjerrum Date: Fri, 22 Nov 2024 16:43:11 +0100 Subject: [PATCH] Updated child classes to honor the safe_inference_mode --- scikit_mol/fingerprints/atompair.py | 4 ++-- scikit_mol/fingerprints/baseclasses.py | 1 - scikit_mol/fingerprints/minhash.py | 2 ++ scikit_mol/fingerprints/morgan.py | 7 ++++++- scikit_mol/fingerprints/rdkitfp.py | 3 ++- scikit_mol/fingerprints/topologicaltorsion.py | 3 ++- 6 files changed, 14 insertions(+), 6 deletions(-) diff --git a/scikit_mol/fingerprints/atompair.py b/scikit_mol/fingerprints/atompair.py index aff8f9f..2198afd 100644 --- a/scikit_mol/fingerprints/atompair.py +++ b/scikit_mol/fingerprints/atompair.py @@ -100,9 +100,10 @@ def __init__( fpSize: int = 2048, useCounts: bool = False, parallel: Union[bool, int] = False, + safe_inference_mode: bool = False, ): self._initializing = True - super().__init__(parallel=parallel) + super().__init__(parallel=parallel, safe_inference_mode=safe_inference_mode) self.fpSize = fpSize self.use2D = use2D self.includeChirality = includeChirality @@ -114,7 +115,6 @@ def __init__( self.fromAtoms = fromAtoms self.ignoreAtoms = ignoreAtoms self.atomInvariants = atomInvariants - self._generate_fp_generator() delattr(self, "_initializing") diff --git a/scikit_mol/fingerprints/baseclasses.py b/scikit_mol/fingerprints/baseclasses.py index 03ca11b..d1eef40 100644 --- a/scikit_mol/fingerprints/baseclasses.py +++ b/scikit_mol/fingerprints/baseclasses.py @@ -127,7 +127,6 @@ def _transform(self, X): if self.safe_inference_mode: # Use the new method with masked arrays if we're in safe inference mode arrays = [self._safe_transform_mol(mol) for mol in X] - print(arrays) return np.ma.stack(arrays) elif hasattr( self, "dtype" diff --git a/scikit_mol/fingerprints/minhash.py b/scikit_mol/fingerprints/minhash.py index 1c7e62a..9d0ec31 100644 --- a/scikit_mol/fingerprints/minhash.py +++ b/scikit_mol/fingerprints/minhash.py @@ -9,6 +9,7 @@ from rdkit.Chem import rdMHFPFingerprint +# TODO move to use FpsGeneratorTransformer class MHFingerprintTransformer(FpsTransformer): def __init__( self, @@ -105,6 +106,7 @@ def n_permutations(self, n_permutations): self._recreate_encoder() +# TODO use FpsGeneratorTransformer instead class SECFingerprintTransformer(FpsTransformer): # https://jcheminf.biomedcentral.com/articles/10.1186/s13321-018-0321-8 def __init__( diff --git a/scikit_mol/fingerprints/morgan.py b/scikit_mol/fingerprints/morgan.py index 37d7cf8..f7d6067 100644 --- a/scikit_mol/fingerprints/morgan.py +++ b/scikit_mol/fingerprints/morgan.py @@ -98,6 +98,7 @@ def __init__( useFeatures=False, useCounts=False, parallel: Union[bool, int] = False, + safe_inference_mode: bool = False, ): """Transform RDKit mols into Count or bit-based hashed MorganFingerprints @@ -115,10 +116,14 @@ def __init__( use chemical features, rather than atom-type in calculation of the fingerprint keys, by default False useCounts : bool, optional If toggled will create the count and not bit-based fingerprint, by default False + parallel : bool or int, optional + If True, will use all available cores, if int will use that many cores, by default False + safe_inference_mode : bool, optional + If True, will return masked arrays for invalid mols, by default False """ self._initializing = True - super().__init__(parallel=parallel) + super().__init__(parallel=parallel, safe_inference_mode=safe_inference_mode) self.fpSize = fpSize self.radius = radius self.useChirality = useChirality diff --git a/scikit_mol/fingerprints/rdkitfp.py b/scikit_mol/fingerprints/rdkitfp.py index 28ce0a8..ad87a26 100644 --- a/scikit_mol/fingerprints/rdkitfp.py +++ b/scikit_mol/fingerprints/rdkitfp.py @@ -114,6 +114,7 @@ def __init__( numBitsPerFeature: int = 2, useCounts: bool = False, parallel: Union[bool, int] = False, + safe_inference_mode: bool = False, ): """Calculates the RDKit fingerprints @@ -139,7 +140,7 @@ def __init__( the number of bits set per path/subgraph found, by default 2 """ self._initializing = True - super().__init__(parallel=parallel) + super().__init__(parallel=parallel, safe_inference_mode=safe_inference_mode) self.minPath = minPath self.maxPath = maxPath self.useHs = useHs diff --git a/scikit_mol/fingerprints/topologicaltorsion.py b/scikit_mol/fingerprints/topologicaltorsion.py index 0b6640d..63b68bf 100644 --- a/scikit_mol/fingerprints/topologicaltorsion.py +++ b/scikit_mol/fingerprints/topologicaltorsion.py @@ -80,9 +80,10 @@ def __init__( fpSize: int = 2048, useCounts: bool = False, parallel: Union[bool, int] = False, + safe_inference_mode: bool = False, ): self._initializing = True - super().__init__(parallel=parallel) + super().__init__(parallel=parallel, safe_inference_mode=safe_inference_mode) self.fpSize = fpSize self.includeChirality = includeChirality self.targetSize = targetSize