Skip to content

Commit

Permalink
Updated child classes to honor the safe_inference_mode
Browse files Browse the repository at this point in the history
  • Loading branch information
Esben Jannik Bjerrum committed Nov 22, 2024
1 parent f092feb commit ff8cf2e
Show file tree
Hide file tree
Showing 6 changed files with 14 additions and 6 deletions.
4 changes: 2 additions & 2 deletions scikit_mol/fingerprints/atompair.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,9 +100,10 @@ def __init__(
fpSize: int = 2048,
useCounts: bool = False,
parallel: Union[bool, int] = False,
safe_inference_mode: bool = False,
):
self._initializing = True
super().__init__(parallel=parallel)
super().__init__(parallel=parallel, safe_inference_mode=safe_inference_mode)
self.fpSize = fpSize
self.use2D = use2D
self.includeChirality = includeChirality
Expand All @@ -114,7 +115,6 @@ def __init__(
self.fromAtoms = fromAtoms
self.ignoreAtoms = ignoreAtoms
self.atomInvariants = atomInvariants

self._generate_fp_generator()
delattr(self, "_initializing")

Expand Down
1 change: 0 additions & 1 deletion scikit_mol/fingerprints/baseclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,6 @@ def _transform(self, X):
if self.safe_inference_mode:
# Use the new method with masked arrays if we're in safe inference mode
arrays = [self._safe_transform_mol(mol) for mol in X]
print(arrays)
return np.ma.stack(arrays)
elif hasattr(
self, "dtype"
Expand Down
2 changes: 2 additions & 0 deletions scikit_mol/fingerprints/minhash.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from rdkit.Chem import rdMHFPFingerprint


# TODO move to use FpsGeneratorTransformer
class MHFingerprintTransformer(FpsTransformer):
def __init__(
self,
Expand Down Expand Up @@ -105,6 +106,7 @@ def n_permutations(self, n_permutations):
self._recreate_encoder()


# TODO use FpsGeneratorTransformer instead
class SECFingerprintTransformer(FpsTransformer):
# https://jcheminf.biomedcentral.com/articles/10.1186/s13321-018-0321-8
def __init__(
Expand Down
7 changes: 6 additions & 1 deletion scikit_mol/fingerprints/morgan.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ def __init__(
useFeatures=False,
useCounts=False,
parallel: Union[bool, int] = False,
safe_inference_mode: bool = False,
):
"""Transform RDKit mols into Count or bit-based hashed MorganFingerprints
Expand All @@ -115,10 +116,14 @@ def __init__(
use chemical features, rather than atom-type in calculation of the fingerprint keys, by default False
useCounts : bool, optional
If toggled will create the count and not bit-based fingerprint, by default False
parallel : bool or int, optional
If True, will use all available cores, if int will use that many cores, by default False
safe_inference_mode : bool, optional
If True, will return masked arrays for invalid mols, by default False
"""

self._initializing = True
super().__init__(parallel=parallel)
super().__init__(parallel=parallel, safe_inference_mode=safe_inference_mode)
self.fpSize = fpSize
self.radius = radius
self.useChirality = useChirality
Expand Down
3 changes: 2 additions & 1 deletion scikit_mol/fingerprints/rdkitfp.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ def __init__(
numBitsPerFeature: int = 2,
useCounts: bool = False,
parallel: Union[bool, int] = False,
safe_inference_mode: bool = False,
):
"""Calculates the RDKit fingerprints
Expand All @@ -139,7 +140,7 @@ def __init__(
the number of bits set per path/subgraph found, by default 2
"""
self._initializing = True
super().__init__(parallel=parallel)
super().__init__(parallel=parallel, safe_inference_mode=safe_inference_mode)
self.minPath = minPath
self.maxPath = maxPath
self.useHs = useHs
Expand Down
3 changes: 2 additions & 1 deletion scikit_mol/fingerprints/topologicaltorsion.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,9 +80,10 @@ def __init__(
fpSize: int = 2048,
useCounts: bool = False,
parallel: Union[bool, int] = False,
safe_inference_mode: bool = False,
):
self._initializing = True
super().__init__(parallel=parallel)
super().__init__(parallel=parallel, safe_inference_mode=safe_inference_mode)
self.fpSize = fpSize
self.includeChirality = includeChirality
self.targetSize = targetSize
Expand Down

0 comments on commit ff8cf2e

Please sign in to comment.