-
Notifications
You must be signed in to change notification settings - Fork 19
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Split fingerprint file into smaller for better overview
- Loading branch information
Esben Jannik Bjerrum
committed
Nov 22, 2024
1 parent
5f91e0c
commit f7b20f1
Showing
9 changed files
with
950 additions
and
747 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
from .baseclasses import ( | ||
FpsTransformer, | ||
FpsGeneratorTransformer, | ||
) # TODO, for backwards compatibility with tests, needs to be removed | ||
|
||
from .atompair import AtomPairFingerprintTransformer, AtomPairFPGeneratorTransformer | ||
from .avalon import AvalonFingerprintTransformer | ||
from .maccs import MACCSKeysFingerprintTransformer | ||
from .minhash import MHFingerprintTransformer, SECFingerprintTransformer | ||
from .morgan import MorganFingerprintTransformer, MorganFPGeneratorTransformer | ||
from .rdkitfp import RDKitFingerprintTransformer, RDKitFPGeneratorTransformer | ||
from .topologicaltorsion import ( | ||
TopologicalTorsionFingerprintTransformer, | ||
TopologicalTorsionFPGeneatorTransformer, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,144 @@ | ||
from typing import Union | ||
|
||
import numpy as np | ||
|
||
from warnings import warn | ||
|
||
from .baseclasses import FpsTransformer, FpsGeneratorTransformer | ||
|
||
from rdkit.Chem.rdFingerprintGenerator import GetAtomPairGenerator | ||
from rdkit.Chem import rdMolDescriptors | ||
|
||
|
||
class AtomPairFingerprintTransformer(FpsTransformer): | ||
def __init__( | ||
self, | ||
minLength: int = 1, | ||
maxLength: int = 30, | ||
fromAtoms=0, | ||
ignoreAtoms=0, | ||
atomInvariants=0, | ||
nBitsPerEntry: int = 4, | ||
includeChirality: bool = False, | ||
use2D: bool = True, | ||
confId: int = -1, | ||
fpSize=2048, | ||
useCounts: bool = False, | ||
parallel: Union[bool, int] = False, | ||
safe_inference_mode: bool = False, | ||
dtype: np.dtype = np.int8, | ||
): | ||
super().__init__( | ||
parallel=parallel, safe_inference_mode=safe_inference_mode, dtype=dtype | ||
) | ||
self.minLength = minLength | ||
self.maxLength = maxLength | ||
self.fromAtoms = fromAtoms | ||
self.ignoreAtoms = ignoreAtoms | ||
self.atomInvariants = atomInvariants | ||
self.includeChirality = includeChirality | ||
self.use2D = use2D | ||
self.confId = confId | ||
self.fpSize = fpSize | ||
self.nBitsPerEntry = nBitsPerEntry | ||
self.useCounts = useCounts | ||
|
||
warn( | ||
"AtomPairFingerprintTransformer will be replace by AtomPairFPGeneratorTransformer, due to changes in RDKit!", | ||
DeprecationWarning, | ||
) | ||
|
||
def _mol2fp(self, mol): | ||
if self.useCounts: | ||
return rdMolDescriptors.GetHashedAtomPairFingerprint( | ||
mol, | ||
nBits=int(self.fpSize), | ||
minLength=int(self.minLength), | ||
maxLength=int(self.maxLength), | ||
fromAtoms=self.fromAtoms, | ||
ignoreAtoms=self.ignoreAtoms, | ||
atomInvariants=self.atomInvariants, | ||
includeChirality=bool(self.includeChirality), | ||
use2D=bool(self.use2D), | ||
confId=int(self.confId), | ||
) | ||
else: | ||
return rdMolDescriptors.GetHashedAtomPairFingerprintAsBitVect( | ||
mol, | ||
nBits=int(self.fpSize), | ||
minLength=int(self.minLength), | ||
maxLength=int(self.maxLength), | ||
fromAtoms=self.fromAtoms, | ||
ignoreAtoms=self.ignoreAtoms, | ||
atomInvariants=self.atomInvariants, | ||
nBitsPerEntry=int(self.nBitsPerEntry), | ||
includeChirality=bool(self.includeChirality), | ||
use2D=bool(self.use2D), | ||
confId=int(self.confId), | ||
) | ||
|
||
|
||
class AtomPairFPGeneratorTransformer(FpsGeneratorTransformer): | ||
_regenerate_on_properties = ( | ||
"fpSize", | ||
"includeChirality", | ||
"use2D", | ||
"minLength", | ||
"maxLength", | ||
) | ||
|
||
def __init__( | ||
self, | ||
minLength: int = 1, | ||
maxLength: int = 30, | ||
fromAtoms=None, | ||
ignoreAtoms=None, | ||
atomInvariants=None, | ||
includeChirality: bool = False, | ||
use2D: bool = True, | ||
confId: int = -1, | ||
fpSize: int = 2048, | ||
useCounts: bool = False, | ||
parallel: Union[bool, int] = False, | ||
): | ||
self._initializing = True | ||
super().__init__(parallel=parallel) | ||
self.fpSize = fpSize | ||
self.use2D = use2D | ||
self.includeChirality = includeChirality | ||
self.minLength = minLength | ||
self.maxLength = maxLength | ||
|
||
self.useCounts = useCounts | ||
self.confId = confId | ||
self.fromAtoms = fromAtoms | ||
self.ignoreAtoms = ignoreAtoms | ||
self.atomInvariants = atomInvariants | ||
|
||
self._generate_fp_generator() | ||
delattr(self, "_initializing") | ||
|
||
def _generate_fp_generator(self): | ||
self._fpgen = GetAtomPairGenerator( | ||
minDistance=self.minLength, | ||
maxDistance=self.maxLength, | ||
includeChirality=self.includeChirality, | ||
use2D=self.use2D, | ||
fpSize=self.fpSize, | ||
) | ||
|
||
def _transform_mol(self, mol) -> np.array: | ||
if self.useCounts: | ||
return self._fpgen.GetCountFingerprintAsNumPy( | ||
mol, | ||
fromAtoms=self.fromAtoms, | ||
ignoreAtoms=self.ignoreAtoms, | ||
customAtomInvariants=self.atomInvariants, | ||
) | ||
else: | ||
return self._fpgen.GetFingerprintAsNumPy( | ||
mol, | ||
fromAtoms=self.fromAtoms, | ||
ignoreAtoms=self.ignoreAtoms, | ||
customAtomInvariants=self.atomInvariants, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,62 @@ | ||
from typing import Union | ||
|
||
import numpy as np | ||
|
||
from .baseclasses import FpsTransformer | ||
|
||
from rdkit.Avalon import pyAvalonTools | ||
|
||
|
||
class AvalonFingerprintTransformer(FpsTransformer): | ||
# Fingerprint from the Avalon toolkeit, https://doi.org/10.1021/ci050413p | ||
def __init__( | ||
self, | ||
fpSize: int = 512, | ||
isQuery: bool = False, | ||
resetVect: bool = False, | ||
bitFlags: int = 15761407, | ||
useCounts: bool = False, | ||
parallel: Union[bool, int] = False, | ||
safe_inference_mode: bool = False, | ||
dtype: np.dtype = np.int8, | ||
): | ||
"""Transform RDKit mols into Count or bit-based Avalon Fingerprints | ||
Parameters | ||
---------- | ||
fpSize : int, optional | ||
Size of the fingerprint, by default 512 | ||
isQuery : bool, optional | ||
use the fingerprint for a query structure, by default False | ||
resetVect : bool, optional | ||
reset vector, by default False NB: only used in GetAvalonFP (not for GetAvalonCountFP) | ||
bitFlags : int, optional | ||
Substructure fingerprint (32767) or similarity fingerprint (15761407) by default 15761407 | ||
useCounts : bool, optional | ||
If toggled will create the count and not bit-based fingerprint, by default False | ||
""" | ||
super().__init__( | ||
parallel=parallel, safe_inference_mode=safe_inference_mode, dtype=dtype | ||
) | ||
self.fpSize = fpSize | ||
self.isQuery = isQuery | ||
self.resetVect = resetVect | ||
self.bitFlags = bitFlags | ||
self.useCounts = useCounts | ||
|
||
def _mol2fp(self, mol): | ||
if self.useCounts: | ||
return pyAvalonTools.GetAvalonCountFP( | ||
mol, | ||
nBits=int(self.fpSize), | ||
isQuery=bool(self.isQuery), | ||
bitFlags=int(self.bitFlags), | ||
) | ||
else: | ||
return pyAvalonTools.GetAvalonFP( | ||
mol, | ||
nBits=int(self.fpSize), | ||
isQuery=bool(self.isQuery), | ||
resetVect=bool(self.resetVect), | ||
bitFlags=int(self.bitFlags), | ||
) |
Oops, something went wrong.